diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 894faad97318..9865e69c0bd8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -721,6 +721,10 @@ title: Qwen2MoE - local: model_doc/qwen3 title: Qwen3 + - local: model_doc/qwen3_5 + title: Qwen3.5 + - local: model_doc/qwen3_5_moe + title: Qwen3.5 Moe - local: model_doc/qwen3_moe title: Qwen3MoE - local: model_doc/qwen3_next diff --git a/docs/source/en/model_doc/qwen3_5.md b/docs/source/en/model_doc/qwen3_5.md new file mode 100644 index 000000000000..24111e180295 --- /dev/null +++ b/docs/source/en/model_doc/qwen3_5.md @@ -0,0 +1,76 @@ + +*This model was released on 2026-01-01 and added to Hugging Face Transformers on 2026-02-09.* + +
+
+PyTorch +FlashAttention +SDPA
+
+ +# Qwen3.5 + +[Qwen3.5](https://huggingface.co/papers/2502.13923) TODO @shuaibai @bozheng + +Model usage + + + + +```py +TODO +``` + + + + +## Qwen3_5Config + +[[autodoc]] Qwen3_5Config + +## Qwen3_5TextConfig + +[[autodoc]] Qwen3_5TextConfig + +## Qwen3_5VisionModel + +[[autodoc]] Qwen3_5VisionModel + - forward + +## Qwen3_5TextModel + +[[autodoc]] Qwen3_5TextModel + - forward + +## Qwen3_5Model + +[[autodoc]] Qwen3_5Model + - forward + +## Qwen3_5ForCausalLM + +[[autodoc]] Qwen3_5ForCausalLM + - forward + +## Qwen3_5ForConditionalGeneration + +[[autodoc]] Qwen3_5ForConditionalGeneration + - forward + +## Qwen3_5Tokenizer + +[[autodoc]] Qwen3_5Tokenizer diff --git a/docs/source/en/model_doc/qwen3_5_moe.md b/docs/source/en/model_doc/qwen3_5_moe.md new file mode 100644 index 000000000000..768839d5276c --- /dev/null +++ b/docs/source/en/model_doc/qwen3_5_moe.md @@ -0,0 +1,72 @@ + +*This model was released on 2026-01-01 and added to Hugging Face Transformers on 2026-02-09.* + +
+
+PyTorch +FlashAttention +SDPA
+
+ +# Qwen3.5 Moe + +[Qwen3.5 Moe](https://huggingface.co/papers/2502.13923) TODO @shuaibai @bozheng + +Model usage + + + + +```py +TODO +``` + + + + +## Qwen3_5MoeConfig + +[[autodoc]] Qwen3_5MoeConfig + +## Qwen3_5MoeTextConfig + +[[autodoc]] Qwen3_5MoeTextConfig + +## Qwen3_5MoeVisionModel + +[[autodoc]] Qwen3_5MoeVisionModel + - forward + +## Qwen3_5MoeTextModel + +[[autodoc]] Qwen3_5MoeTextModel + - forward + +## Qwen3_5MoeModel + +[[autodoc]] Qwen3_5MoeModel + - forward + +## Qwen3_5MoeForCausalLM + +[[autodoc]] Qwen3_5MoeForCausalLM + - forward + +## Qwen3_5MoeForConditionalGeneration + +[[autodoc]] Qwen3_5MoeForConditionalGeneration + - forward diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py old mode 100644 new mode 100755 index bba6a4b6ea34..e06c37a0a30a --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -59,6 +59,7 @@ "qwen3_omni_moe": "qwen2_moe", "qwen3_omni_moe_thinker": "qwen2_moe", "qwen3_next": "qwen2_moe", + "qwen3_5_moe": "qwen2_moe", "hunyuan_v1_moe": "qwen2_moe", "flex_olmo": "qwen2_moe", "olmoe": "qwen2_moe", @@ -70,6 +71,9 @@ def _build_checkpoint_conversion_mapping(): mapping = { + "qwen3_5_text": [ + WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), + ], "t5gemma2": [ WeightRenaming(r"(?>> from transformers import Qwen3_5TextModel, Qwen3_5TextConfig + + >>> # Initializing a Qwen3.5 style configuration + >>> configuration = Qwen3_5TextConfig() + + >>> # Initializing a model from the Qwen3.5-9B style configuration + >>> model = Qwen3_5TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_5_text" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + base_config_key = "text_config" + + def __init__( + self, + vocab_size=248320, + hidden_size=4096, + intermediate_size=12288, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + layer_types=None, + pad_token_id: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + **kwargs, + ): + kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"} + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + self.rope_parameters = rope_parameters + kwargs.setdefault("partial_rotary_factor", 0.25) # assign default for BC + + self.layer_types = layer_types + if self.layer_types is None: + interval_pattern = kwargs.get("full_attention_interval", 4) + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + super().__init__(**kwargs) + + +class Qwen3_5VisionConfig(PreTrainedConfig): + model_type = "qwen3_5" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + + +class Qwen3_5Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5Model`]. It is used to instantiate a + Qwen3.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3.5-9B-Instruct [Qwen/Qwen3.5-9B-Instruct](https://huggingface.co/Qwen/Qwen3.5-9B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 248056): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 248057): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 248053): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 248054): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Config + + >>> # Initializing a Qwen3.5 style configuration + >>> configuration = Qwen3_5Config() + + >>> # Initializing a model from the Qwen3.5-9B style configuration + >>> model = Qwen3_5ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_5" + sub_configs = {"vision_config": Qwen3_5VisionConfig, "text_config": Qwen3_5TextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.tie_word_embeddings = tie_word_embeddings + super().__init__(**kwargs) + + +__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig"] diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py new file mode 100644 index 000000000000..6aecdb06570e --- /dev/null +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -0,0 +1,2194 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_5/modular_qwen3_5.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_5.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_kernelized_func +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, + ModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast +from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available +from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_linear_attention_available(): + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +else: + chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None + FusedRMSNormGated = None + +logger = logging.get_logger(__name__) + + +class Qwen3_5DynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention + cache (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + is_compileable = False + + def __init__(self, config: Qwen3_5Config): + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + + # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference + self.conv_states = [None for _ in range(config.num_hidden_layers)] + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] + + def __len__(self): + return len(self.layer_types) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] is not None: + device = self.key_cache[layer_idx].device + beam_idx = beam_idx.to(device) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + + if self.conv_states[layer_idx] is not None: + device = self.conv_states[layer_idx].device + beam_idx = beam_idx.to(device) + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) + self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + return self.conv_states[self.last_linear_layer] is not None + + +class Qwen3_5VisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3_5TextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3_5TextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10]) + + @staticmethod + def compute_default_rope_parameters( + config: Qwen3_5TextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3_5 has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + +class Qwen3_5RMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +is_fast_path_available = all( + (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +) + + +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out + + +def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def torch_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class Qwen3_5GatedDeltaNet(nn.Module): + def __init__(self, config: Qwen3_5Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = ( + Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + if FusedRMSNormGated is None + else FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), + ) + ) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of the required library is not installed. Falling back to " + "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Qwen3_5DynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class Qwen3_5Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3_5Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output * torch.sigmoid(gate) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3_5MLP(nn.Module): + def __init__(self, config: Qwen3_5Config, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3_5RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Qwen3_5 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Qwen3_5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5GatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3_5Attention(config, layer_idx) + self.mlp = Qwen3_5MLP(config, config.intermediate_size) + self.input_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3_5PreTrainedModel(PreTrainedModel): + config: Qwen3_5Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] + _can_record_outputs = { + "hidden_states": Qwen3_5DecoderLayer, + "attentions": Qwen3_5Attention, + } + _is_stateful = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Qwen3_5GatedDeltaNet): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, Qwen3_5RMSNorm): + init.zeros_(module.weight) + elif isinstance(module, Qwen3_5VisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) + + +class Qwen3_5VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3_5VisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3_5VisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3_5VisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen3_5VisionAttention(nn.Module): + def __init__(self, config: Qwen3_5VisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3_5VisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3_5VisionAttention(config=config) + self.mlp = Qwen3_5VisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3_5VisionModel(Qwen3_5PreTrainedModel): + config: Qwen3_5VisionConfig + _no_split_modules = ["Qwen3_5VisionBlock"] + _can_record_outputs = { + "hidden_states": Qwen3_5VisionBlock, + "attentions": Qwen3_5VisionAttention, + } + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3_5VisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3_5VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3_5VisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3_5VisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.gradient_checkpointing = False + + self.post_init() + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + @check_model_inputs + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + + merged_hidden_states = self.merger(hidden_states) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=merged_hidden_states, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3_5ModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + +class Qwen3_5TextModel(Qwen3_5PreTrainedModel): + def __init__(self, config: Qwen3_5TextConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [Qwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3_5TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Qwen3_5DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # mrope: the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return Qwen3_5ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _update_linear_attn_mask(self, attention_mask, cache_position): + """ + NOTE: Left-padding is used for linear attention mask. + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + +@auto_docstring +class Qwen3_5Model(Qwen3_5PreTrainedModel): + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3_5Config + _no_split_modules = ["Qwen3_5TextDecoderLayer", "Qwen3_5VisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3_5VisionModel._from_config(config.vision_config) + self.language_model = Qwen3_5TextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3_5 use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to separate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @can_return_tuple + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) + image_embeds = vision_output.pooler_output + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + vision_output.pooler_output = image_embeds + + return vision_output + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor | None = None, + video_features: torch.FloatTensor | None = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None: + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None: + torch_compilable_check( + inputs_embeds[special_video_mask].numel() == video_features.numel(), + f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", + ) + return special_image_mask, special_video_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen3_5ModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_outputs: BaseModelOutputWithPooling = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True + ) + image_embeds = image_outputs.pooler_output + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_outputs: BaseModelOutputWithPooling = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True + ) + video_embeds = video_outputs.pooler_output + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + if self.rope_deltas is None or past_key_values_length == 0: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return Qwen3_5ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + rope_deltas=self.rope_deltas, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config: Qwen3_5TextConfig + _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3_5TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM + + >>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3_5-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3_5-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3_5 causal language model (or autoregressive) outputs. + """ +) +class Qwen3_5CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + +class Qwen3_5ForConditionalGeneration(Qwen3_5PreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3_5Config + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3_5Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + return self.model.get_video_features( + pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + ) + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen3_5CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + + ```python + >>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration + + >>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", + }, + {"type": "text", "text": "Describe the image."}, + ], + } + ] + + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=1024) + >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(output_text) + ``` + """ + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Qwen3_5CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + # Qwen3_5 position_ids are prepared with rope_deltas + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None: + vision_positions, rope_deltas = self.model.get_rope_index( + model_inputs.get("input_ids", None), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + self.model.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + elif "position_ids" in model_inputs: + batch_size, seq_length = model_inputs["position_ids"].shape + device = model_inputs["position_ids"].device + position_ids = torch.arange(seq_length, device=device) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + delta = cache_position[0] + self.model.rope_deltas + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + vision_positions = position_ids + delta.expand_as(position_ids) + + # Concatenate "text + vision" positions into [4, bs, seq-len] + text_positions = model_inputs["position_ids"][None, ...] + model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0) + + if not is_first_iteration and use_cache: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Qwen3_5 use timestamps and remove second_per_grid_ts + # Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + # video_nums: (batch_size,) + # since video_nums is the number of videos in the input dependent on the input_ids(vision_start), + # but Qwen3_5 append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw + if video_grid_thw is not None: + cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0) + cumulative_token_video_counts = torch.cumsum(video_nums, dim=0) + # Find video boundaries in cumulative_frame_counts + video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts) + # example: video_boundary_indices = [3, 5] means video_nums = [4, 2] + video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices])) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3_5VisionModel", + "Qwen3_5TextModel", + "Qwen3_5Model", + "Qwen3_5ForCausalLM", + "Qwen3_5ForConditionalGeneration", + "Qwen3_5PreTrainedModel", +] diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py new file mode 100644 index 000000000000..9f61318c5740 --- /dev/null +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -0,0 +1,841 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3.5 model.""" + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...cache_utils import Cache +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_rope_utils import RopeParameters +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs +from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM +from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig +from ..qwen3_next.modeling_qwen3_next import ( + Qwen3NextAttention, + Qwen3NextDynamicCache, + Qwen3NextGatedDeltaNet, + Qwen3NextMLP, + Qwen3NextModel, + Qwen3NextPreTrainedModel, + Qwen3NextRMSNorm, + apply_mask_to_padding_states, +) +from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig +from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLModelOutputWithPast, + Qwen3VLTextRotaryEmbedding, + Qwen3VLVisionModel, + Qwen3VLVisionRotaryEmbedding, +) + + +logger = logging.get_logger(__name__) + + +class Qwen3_5TextConfig(Qwen3NextConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5TextModel`]. It is used to instantiate a + Qwen3_5 model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3.5-9B-Instruct [Qwen/Qwen3.5-9B-Instruct](https://huggingface.co/Qwen/Qwen3.5-9B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 248320): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*): + End of stream token id. + + ```python + >>> from transformers import Qwen3_5TextModel, Qwen3_5TextConfig + + >>> # Initializing a Qwen3.5 style configuration + >>> configuration = Qwen3_5TextConfig() + + >>> # Initializing a model from the Qwen3.5-9B style configuration + >>> model = Qwen3_5TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_5_text" + base_config_key = "text_config" + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=248320, + hidden_size=4096, + intermediate_size=12288, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + layer_types=None, + pad_token_id: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + **kwargs, + ): + kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"} + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + del self.decoder_sparse_step + del self.norm_topk_prob + del self.mlp_only_layers + del self.moe_intermediate_size + del self.shared_expert_intermediate_size + del self.num_experts_per_tok + del self.num_experts + del self.output_router_logits + del self.router_aux_loss_coef + + +class Qwen3_5VisionConfig(Qwen3VLVisionConfig): + model_type = "qwen3_5" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + del self.deepstack_visual_indexes + + +class Qwen3_5Config(Qwen3VLConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5Model`]. It is used to instantiate a + Qwen3.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3.5-9B-Instruct [Qwen/Qwen3.5-9B-Instruct](https://huggingface.co/Qwen/Qwen3.5-9B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 248056): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 248057): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 248053): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 248054): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Config + + >>> # Initializing a Qwen3.5 style configuration + >>> configuration = Qwen3_5Config() + + >>> # Initializing a model from the Qwen3.5-9B style configuration + >>> model = Qwen3_5ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_5" + sub_configs = {"vision_config": Qwen3_5VisionConfig, "text_config": Qwen3_5TextConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + tie_word_embeddings=False, + **kwargs, + ): + super().__init__( + text_config=text_config, + vision_config=vision_config, + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Qwen3_5DynamicCache(Qwen3NextDynamicCache): + pass + + +class Qwen3_5VisionRotaryEmbedding(Qwen3VLVisionRotaryEmbedding): + pass + + +class Qwen3_5TextRotaryEmbedding(Qwen3VLTextRotaryEmbedding): + def __init__(self, config: Qwen3_5TextConfig, device=None): + super().__init__() + self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10]) + + def compute_default_rope_parameters( + config: Qwen3_5TextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + +class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): + def __init__(self, config: Qwen3_5Config, layer_idx: int): + super().__init__(config, layer_idx) + + del projection_size_qkvz # noqa: F821 + del projection_size_ba # noqa: F821 + del self.in_proj_qkvz + del self.in_proj_ba + + self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + def fix_query_key_value_ordering(self): + raise AttributeError("Not needed for Qwen3.5 Series") + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Qwen3_5DynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + +class Qwen3_5Attention(Qwen3NextAttention): + pass + + +class Qwen3_5MLP(Qwen3NextMLP): + def __init__(self, config: Qwen3_5Config, intermediate_size: int): + super().__init__(config, intermediate_size) + self.intermediate_size = intermediate_size + + +class Qwen3_5RMSNorm(Qwen3NextRMSNorm): + pass + + +class Qwen3_5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5GatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3_5Attention(config, layer_idx) + self.mlp = Qwen3_5MLP(config, config.intermediate_size) + self.input_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3_5PreTrainedModel(Qwen3NextPreTrainedModel): + config: Qwen3_5Config + _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"] + _can_record_outputs = { + "hidden_states": Qwen3_5DecoderLayer, + "attentions": Qwen3_5Attention, + } + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, Qwen3_5GatedDeltaNet): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, Qwen3_5RMSNorm): + init.zeros_(module.weight) + elif isinstance(module, Qwen3_5VisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) + + +class Qwen3_5VisionModel(Qwen3VLVisionModel): + config: Qwen3_5VisionConfig + _no_split_modules = ["Qwen3_5VisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + del self.deepstack_visual_indexes + del self.deepstack_merger_list + + @check_model_inputs + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + + merged_hidden_states = self.merger(hidden_states) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=merged_hidden_states, + ) + + +class Qwen3_5ModelOutputWithPast(Qwen3VLModelOutputWithPast): + pass + + +class Qwen3_5TextModel(Qwen3NextModel): + def __init__(self, config: Qwen3_5TextConfig): + super().__init__(config) + self.rotary_emb = Qwen3_5TextRotaryEmbedding(config=config) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Qwen3_5DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # mrope: the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return Qwen3_5ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class Qwen3_5Model(Qwen3VLModel): + def get_video_features( + self, + **super_kwargs, + ) -> tuple | BaseModelOutputWithPooling: + # Same implementation as for images + return super().get_video_features(**super_kwargs) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + pixel_values = pixel_values.type(self.visual.dtype) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) + image_embeds = vision_output.pooler_output + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + vision_output.pooler_output = image_embeds + + return vision_output + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen3_5ModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_outputs: BaseModelOutputWithPooling = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True + ) + image_embeds = image_outputs.pooler_output + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_outputs: BaseModelOutputWithPooling = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True + ) + video_embeds = video_outputs.pooler_output + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + if self.rope_deltas is None or past_key_values_length == 0: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return Qwen3_5ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + rope_deltas=self.rope_deltas, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen3_5ForCausalLM(Qwen3ForCausalLM): + config: Qwen3_5TextConfig + _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3_5TextModel(config) + + +class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration): + def get_video_features( + self, + **super_kwargs, + ) -> tuple | BaseModelOutputWithPooling: + return super().get_video_features(**super_kwargs) + + def get_image_features( + self, + **super_kwargs, + ) -> tuple | BaseModelOutputWithPooling: + return super().get_image_features(**super_kwargs) + + +__all__ = [ + "Qwen3_5Config", + "Qwen3_5TextConfig", + "Qwen3_5VisionModel", + "Qwen3_5TextModel", + "Qwen3_5Model", + "Qwen3_5ForCausalLM", + "Qwen3_5ForConditionalGeneration", + "Qwen3_5PreTrainedModel", +] diff --git a/src/transformers/models/qwen3_5/tokenization_qwen3_5.py b/src/transformers/models/qwen3_5/tokenization_qwen3_5.py new file mode 100644 index 000000000000..28049a4deb4c --- /dev/null +++ b/src/transformers/models/qwen3_5/tokenization_qwen3_5.py @@ -0,0 +1,94 @@ +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Qwen3.5.""" + +from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers +from tokenizers.models import BPE + +from ...tokenization_utils_tokenizers import TokenizersBackend +from ...utils import logging + + +logger = logging.get_logger(__name__) + +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+|\p{N}| ?[^\s\p{L}\p{M}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + +class Qwen3_5Tokenizer(TokenizersBackend): + model_input_names = ["input_ids", "attention_mask"] + model = BPE + + def __init__( + self, + vocab: str | dict[str, int] | None = None, + merges: str | list[str] | None = None, + vocab_file=None, + merges_file=None, + unk_token: str = "<|endoftext|>", + bos_token=None, + eos_token: str = "<|endoftext|>", + pad_token: str = "<|endoftext|>", + add_prefix_space=None, + **kwargs, + ): + self.add_prefix_space = add_prefix_space if add_prefix_space is not None else False + self._vocab = ( + vocab + if vocab is not None + else { + "<|endoftext|>": 0, + } + ) + self._merges = merges or [] + self._tokenizer = Tokenizer( + BPE( + vocab=self._vocab, + merges=self._merges, + dropout=None, + unk_token=None, + continuing_subword_prefix="", + end_of_word_suffix="", + fuse_unk=False, + byte_fallback=False, + ) + ) + self._tokenizer.decoder = decoders.ByteLevel() + self._tokenizer.normalizer = normalizers.NFC() + self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split( + Regex(PRETOKENIZE_REGEX), + behavior="isolated", + invert=False, + ), + pre_tokenizers.ByteLevel( + add_prefix_space=self.add_prefix_space, + use_regex=False, + ), + ] + ) + + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + +__all__ = ["Qwen3_5Tokenizer"] diff --git a/src/transformers/models/qwen3_5_moe/__init__.py b/src/transformers/models/qwen3_5_moe/__init__.py new file mode 100644 index 000000000000..fabf00e524e6 --- /dev/null +++ b/src/transformers/models/qwen3_5_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_5_moe import * + from .modeling_qwen3_5_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py new file mode 100644 index 000000000000..13ee7ba77a42 --- /dev/null +++ b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py @@ -0,0 +1,330 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_5_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PreTrainedConfig, layer_type_validation +from ...modeling_rope_utils import RopeParameters + + +class Qwen3_5MoeTextConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5MoeTextModel`]. It is used to instantiate a + Qwen3.5-MoE model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 248320): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 256): + Number of routed experts. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*): + End of stream token id. + + ```python + >>> from transformers import Qwen3_5MoeTextModel, Qwen3_5MoeTextConfig + + >>> # Initializing a Qwen3.5-MoE style configuration + >>> configuration = Qwen3_5MoeTextConfig() + + >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration + >>> model = Qwen3_5MoeTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_5_moe_text" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.shared_expert.gate_proj": "colwise", + "layers.*.mlp.shared_expert.up_proj": "colwise", + "layers.*.mlp.shared_expert.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + base_config_key = "text_config" + + def __init__( + self, + vocab_size=248320, + hidden_size=2048, + num_hidden_layers=40, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=8, + num_experts=256, + output_router_logits=False, + router_aux_loss_coef=0.001, + layer_types=None, + pad_token_id: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + **kwargs, + ): + kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"} + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + self.rope_parameters = rope_parameters + kwargs.setdefault("partial_rotary_factor", 0.25) # assign default for BC + + self.layer_types = layer_types + if self.layer_types is None: + interval_pattern = kwargs.get("full_attention_interval", 4) + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + super().__init__(**kwargs) + + +class Qwen3_5MoeVisionConfig(PreTrainedConfig): + model_type = "qwen3_5_moe" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + + +class Qwen3_5MoeConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5MoeModel`]. It is used to instantiate a + Qwen3.5-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 248056): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 248057): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 248053): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 248054): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeConfig + + >>> # Initializing a Qwen3.5-MoE style configuration + >>> configuration = Qwen3_5MoeConfig() + + >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration + >>> model = Qwen3_5MoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_5_moe" + sub_configs = {"vision_config": Qwen3_5MoeVisionConfig, "text_config": Qwen3_5MoeTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.tie_word_embeddings = tie_word_embeddings + super().__init__(**kwargs) + + +__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"] diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py new file mode 100644 index 000000000000..09f64ce756c6 --- /dev/null +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -0,0 +1,2414 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_5_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernelized_func +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + ModelOutput, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast +from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available +from ...utils.output_capturing import OutputRecorder +from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig, Qwen3_5MoeVisionConfig + + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_linear_attention_available(): + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +else: + chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None + FusedRMSNormGated = None + +logger = logging.get_logger(__name__) + + +class Qwen3_5MoeVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3_5MoeTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3_5MoeTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10]) + + @staticmethod + def compute_default_rope_parameters( + config: Qwen3_5MoeTextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3_5Moe has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + +class Qwen3_5MoeDynamicCache: + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention + cache (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + is_compileable = False + + def __init__(self, config: Qwen3_5MoeConfig): + super().__init__() + self.layer_types = config.layer_types + self.transformer_layers = [ + i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" + ] + self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") + + # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference + self.conv_states = [None for _ in range(config.num_hidden_layers)] + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] + + def __len__(self): + return len(self.layer_types) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] is not None: + device = self.key_cache[layer_idx].device + beam_idx = beam_idx.to(device) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + + if self.conv_states[layer_idx] is not None: + device = self.conv_states[layer_idx].device + beam_idx = beam_idx.to(device) + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) + self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + return self.conv_states[self.last_linear_layer] is not None + + +class Qwen3_5MoeRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + return hidden_states.to(input_dtype) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +is_fast_path_available = all( + (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +) + + +def torch_causal_conv1d_update( + hidden_states, + conv_state, + weight, + bias=None, + activation=None, +): + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) + out = F.silu(out[:, :, -seq_len:]) + out = out.to(hidden_states.dtype) + return out + + +def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def torch_recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class Qwen3_5MoeGatedDeltaNet(nn.Module): + def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = ( + Qwen3_5MoeRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + if FusedRMSNormGated is None + else FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + activation=self.activation, + device=torch.cuda.current_device(), + dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), + ) + ) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.causal_conv1d_fn = causal_conv1d_fn + self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of the required library is not installed. Falling back to " + "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Qwen3_5MoeDynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class Qwen3_5MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3_5MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = Qwen3_5MoeRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output * torch.sigmoid(gate) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3_5MoeMLP(nn.Module): + def __init__(self, config: Qwen3_5MoeConfig, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +@use_experts_implementation +class Qwen3_5MoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class Qwen3_5MoeTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices + + +class Qwen3_5MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = Qwen3_5MoeTopKRouter(config) + self.experts = Qwen3_5MoeExperts(config) + self.shared_expert = Qwen3_5MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + shared_expert_output = self.shared_expert(hidden_states_reshaped) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output + + expert_output += shared_expert_output + expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) + return expert_output + + +class Qwen3_5MoeRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Qwen3_5Moe is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Qwen3_5MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3_5MoeAttention(config, layer_idx) + self.mlp = Qwen3_5MoeSparseMoeBlock(config) + self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3_5MoePreTrainedModel(PreTrainedModel): + config: Qwen3_5MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3_5MoeDecoderLayer", "Qwen3_5MoeVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] + _can_record_outputs = { + "router_logits": OutputRecorder(Qwen3_5MoeSparseMoeBlock, index=1), + "hidden_states": Qwen3_5MoeDecoderLayer, + "attentions": Qwen3_5MoeAttention, + } + _is_stateful = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Qwen3_5MoeGatedDeltaNet): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, Qwen3_5MoeRMSNorm): + init.zeros_(module.weight) + elif isinstance(module, Qwen3_5MoeExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Qwen3_5MoeSparseMoeBlock): + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Qwen3_5MoeVisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) + + +class Qwen3_5MoeVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3_5MoeVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3_5MoeVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3_5MoeVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen3_5MoeVisionAttention(nn.Module): + def __init__(self, config: Qwen3_5MoeVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3_5MoeVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3_5MoeVisionAttention(config=config) + self.mlp = Qwen3_5MoeVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3_5MoeVisionModel(Qwen3_5MoePreTrainedModel): + config: Qwen3_5MoeVisionConfig + _no_split_modules = ["Qwen3_5MoeVisionBlock"] + _can_record_outputs = { + "hidden_states": Qwen3_5MoeVisionBlock, + "attentions": Qwen3_5MoeVisionAttention, + } + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3_5MoeVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3_5MoeVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3_5MoeVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.gradient_checkpointing = False + + self.post_init() + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + @check_model_inputs + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + + merged_hidden_states = self.merger(hidden_states) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=merged_hidden_states, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3_5MoeModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + router_logits: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3_5Moe causal language model (or autoregressive) outputs. + """ +) +class Qwen3_5MoeCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + aux_loss: torch.FloatTensor | None = None + + +class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel): + def __init__(self, config: Qwen3_5MoeTextConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [Qwen3_5MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Qwen3_5MoeDynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # mrope: the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return Qwen3_5MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _update_linear_attn_mask(self, attention_mask, cache_position): + """ + NOTE: Left-padding is used for linear attention mask. + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + +@auto_docstring +class Qwen3_5MoeModel(Qwen3_5MoePreTrainedModel): + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3_5MoeConfig + _no_split_modules = ["Qwen3_5MoeTextDecoderLayer", "Qwen3_5MoeVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3_5MoeVisionModel._from_config(config.vision_config) + self.language_model = Qwen3_5MoeTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3_5Moe use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to separate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @can_return_tuple + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) + image_embeds = vision_output.pooler_output + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + vision_output.pooler_output = image_embeds + + return vision_output + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor | None = None, + video_features: torch.FloatTensor | None = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None: + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None: + torch_compilable_check( + inputs_embeds[special_video_mask].numel() == video_features.numel(), + f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", + ) + return special_image_mask, special_video_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen3_5MoeModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_outputs: BaseModelOutputWithPooling = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True + ) + image_embeds = image_outputs.pooler_output + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_outputs: BaseModelOutputWithPooling = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True + ) + video_embeds = video_outputs.pooler_output + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + if self.rope_deltas is None or past_key_values_length == 0: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return Qwen3_5MoeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + rope_deltas=self.rope_deltas, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def load_balancing_loss_func( + gate_logits: torch.Tensor | tuple[torch.Tensor] | None, + num_experts: int | None = None, + top_k=2, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor | int: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config: Qwen3_5MoeTextConfig + _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3_5MoeTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Qwen3_5MoeDynamicCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_router_logits: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3_5MoeForCausalLM + + >>> model = Qwen3_5MoeForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class Qwen3_5MoeForConditionalGeneration(Qwen3_5MoePreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3_5MoeConfig + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3_5MoeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + return self.model.get_video_features( + pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + ) + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen3_5MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration + + >>> model = Qwen3_5MoeForConditionalGeneration.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." + ```""" + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return Qwen3_5MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + # Qwen3_5Moe position_ids are prepared with rope_deltas + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None: + vision_positions, rope_deltas = self.model.get_rope_index( + model_inputs.get("input_ids", None), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + self.model.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + elif "position_ids" in model_inputs: + batch_size, seq_length = model_inputs["position_ids"].shape + device = model_inputs["position_ids"].device + position_ids = torch.arange(seq_length, device=device) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + delta = cache_position[0] + self.model.rope_deltas + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + vision_positions = position_ids + delta.expand_as(position_ids) + + # Concatenate "text + vision" positions into [4, bs, seq-len] + text_positions = model_inputs["position_ids"][None, ...] + model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0) + + if not is_first_iteration and use_cache: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Qwen3_5Moe use timestamps and remove second_per_grid_ts + # Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + # video_nums: (batch_size,) + # since video_nums is the number of videos in the input dependent on the input_ids(vision_start), + # but Qwen3_5Moe append vision_start to each frame of each video, so we need to recover the real video_nums according to video_grid_thw + if video_grid_thw is not None: + cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0) + cumulative_token_video_counts = torch.cumsum(video_nums, dim=0) + # Find video boundaries in cumulative_frame_counts + video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts) + # example: video_boundary_indices = [3, 5] means video_nums = [4, 2] + video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices])) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3_5MoeVisionModel", + "Qwen3_5MoeTextModel", + "Qwen3_5MoeModel", + "Qwen3_5MoeForCausalLM", + "Qwen3_5MoeForConditionalGeneration", + "Qwen3_5MoePreTrainedModel", +] diff --git a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py new file mode 100644 index 000000000000..3369cf363ee9 --- /dev/null +++ b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py @@ -0,0 +1,464 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3.5Moe model.""" + +import torch + +from ... import initialization as init +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPooling +from ...modeling_rope_utils import RopeParameters +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from ..qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig +from ..qwen3_5.modeling_qwen3_5 import ( + Qwen3_5GatedDeltaNet, + Qwen3_5MLP, + Qwen3_5Model, + Qwen3_5TextModel, + Qwen3_5TextRotaryEmbedding, + Qwen3_5VisionModel, + Qwen3_5VisionRotaryEmbedding, +) +from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig +from ..qwen3_next.modeling_qwen3_next import ( + Qwen3NextAttention, + Qwen3NextDecoderLayer, + Qwen3NextDynamicCache, + Qwen3NextExperts, + Qwen3NextForCausalLM, + Qwen3NextPreTrainedModel, + Qwen3NextRMSNorm, + Qwen3NextSparseMoeBlock, +) +from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig +from ..qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeCausalLMOutputWithPast, + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModelOutputWithPast, + Qwen3VLMoeTextTopKRouter, +) + + +logger = logging.get_logger(__name__) + + +class Qwen3_5MoeTextConfig(Qwen3NextConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5MoeTextModel`]. It is used to instantiate a + Qwen3.5-MoE model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 248320): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 256): + Number of routed experts. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*): + End of stream token id. + + ```python + >>> from transformers import Qwen3_5MoeTextModel, Qwen3_5MoeTextConfig + + >>> # Initializing a Qwen3.5-MoE style configuration + >>> configuration = Qwen3_5MoeTextConfig() + + >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration + >>> model = Qwen3_5MoeTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_5_moe_text" + base_config_key = "text_config" + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.shared_expert.gate_proj": "colwise", + "layers.*.mlp.shared_expert.up_proj": "colwise", + "layers.*.mlp.shared_expert.down_proj": "rowwise", + } + + def __init__( + self, + vocab_size=248320, + hidden_size=2048, + num_hidden_layers=40, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=8, + num_experts=256, + output_router_logits=False, + router_aux_loss_coef=0.001, + layer_types=None, + pad_token_id: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + **kwargs, + ): + kwargs["ignore_keys_at_rope_validation"] = {"mrope_section", "mrope_interleaved"} + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + del self.intermediate_size + del self.decoder_sparse_step + del self.norm_topk_prob + del self.mlp_only_layers + + +class Qwen3_5MoeVisionConfig(Qwen3_5VisionConfig): + pass + + +class Qwen3_5MoeConfig(Qwen3VLConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3_5MoeModel`]. It is used to instantiate a + Qwen3.5-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen3.5-35B-A3B-Instruct [Qwen/Qwen3.5-35B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3.5-35B-A3B-Instruct). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 248056): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 248057): + The video token index to encode the image prompt. + vision_start_token_id (`int`, *optional*, defaults to 248053): + The start token index to encode the image prompt. + vision_end_token_id (`int`, *optional*, defaults to 248054): + The end token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings. + + ```python + >>> from transformers import Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeConfig + + >>> # Initializing a Qwen3.5-MoE style configuration + >>> configuration = Qwen3_5MoeConfig() + + >>> # Initializing a model from the Qwen3.5-35B-A3B style configuration + >>> model = Qwen3_5MoeForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_5_moe" + sub_configs = {"vision_config": Qwen3_5MoeVisionConfig, "text_config": Qwen3_5MoeTextConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + tie_word_embeddings=False, + **kwargs, + ): + super().__init__( + text_config=text_config, + vision_config=vision_config, + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Qwen3_5MoeVisionRotaryEmbedding(Qwen3_5VisionRotaryEmbedding): + pass + + +class Qwen3_5MoeTextRotaryEmbedding(Qwen3_5TextRotaryEmbedding): + pass + + +class Qwen3_5MoeDynamicCache(Qwen3NextDynamicCache): + pass + + +class Qwen3_5MoeGatedDeltaNet(Qwen3_5GatedDeltaNet): + pass + + +class Qwen3_5MoeAttention(Qwen3NextAttention): + pass + + +class Qwen3_5MoeMLP(Qwen3_5MLP): + pass + + +class Qwen3_5MoeExperts(Qwen3NextExperts): + pass + + +class Qwen3_5MoeTopKRouter(Qwen3VLMoeTextTopKRouter): + pass + + +class Qwen3_5MoeSparseMoeBlock(Qwen3NextSparseMoeBlock): + pass + + +class Qwen3_5MoeRMSNorm(Qwen3NextRMSNorm): + pass + + +class Qwen3_5MoeDecoderLayer(Qwen3NextDecoderLayer): + def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int): + GradientCheckpointingLayer.__init__(self) + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3_5MoeAttention(config, layer_idx) + self.mlp = Qwen3_5MoeSparseMoeBlock(config) + self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class Qwen3_5MoePreTrainedModel(Qwen3NextPreTrainedModel): + _no_split_modules = ["Qwen3_5MoeDecoderLayer", "Qwen3_5MoeVisionBlock"] + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, Qwen3_5MoeGatedDeltaNet): + init.ones_(module.dt_bias) + init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif isinstance(module, Qwen3_5MoeRMSNorm): + init.zeros_(module.weight) + elif isinstance(module, Qwen3_5MoeExperts): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Qwen3_5MoeSparseMoeBlock): + init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Qwen3_5MoeVisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) + + +class Qwen3_5MoeVisionModel(Qwen3_5VisionModel): + pass + + +class Qwen3_5MoeModelOutputWithPast(Qwen3VLMoeModelOutputWithPast): + router_logits: tuple[torch.FloatTensor] | None = None + + +class Qwen3_5MoeCausalLMOutputWithPast(Qwen3VLMoeCausalLMOutputWithPast): + pass + + +class Qwen3_5MoeTextModel(Qwen3_5TextModel): + pass + + +class Qwen3_5MoeModel(Qwen3_5Model): + pass + + +class Qwen3_5MoeForCausalLM(Qwen3NextForCausalLM): + config: Qwen3_5MoeTextConfig + _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3_5MoeTextModel(config) + + +class Qwen3_5MoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration + + >>> model = Qwen3_5MoeForConditionalGeneration.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." + ```""" + super().forward(**super_kwargs) + + def get_video_features( + self, + **super_kwargs, + ) -> tuple | BaseModelOutputWithPooling: + return super().get_video_features(**super_kwargs) + + def get_image_features( + self, + **super_kwargs, + ) -> tuple | BaseModelOutputWithPooling: + return super().get_image_features(**super_kwargs) + + +__all__ = [ + "Qwen3_5MoeConfig", + "Qwen3_5MoeTextConfig", + "Qwen3_5MoeVisionModel", + "Qwen3_5MoeTextModel", + "Qwen3_5MoeModel", + "Qwen3_5MoeForCausalLM", + "Qwen3_5MoeForConditionalGeneration", + "Qwen3_5MoePreTrainedModel", +] diff --git a/tests/models/qwen3_5/__init__.py b/tests/models/qwen3_5/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py new file mode 100644 index 000000000000..da025439acb4 --- /dev/null +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -0,0 +1,402 @@ +# Copyright 2026 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Qwen3.5 model.""" + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, +) + + +if is_torch_available(): + import torch + + from transformers import ( + Qwen3_5Config, + Qwen3_5ForCausalLM, + Qwen3_5ForConditionalGeneration, + Qwen3_5Model, + Qwen3_5TextConfig, + Qwen3_5TextModel, + ) + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache + + +class Qwen3_5TextModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = Qwen3_5TextModel + causal_lm_class = Qwen3_5ForCausalLM + + def __init__(self, parent): + super().__init__(parent=parent) + self.layer_types = ["full_attention", "linear_attention"] + self.linear_conv_kernel_dim = 2 + self.linear_key_head_dim = 16 + self.linear_value_head_dim = 16 + self.linear_num_key_heads = 4 + self.linear_num_value_heads = 8 + + +@require_torch +class Qwen3_5TextModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = Qwen3_5TextModelTester + config_class = Qwen3_5TextConfig + model_split_percents = [0.5, 0.8, 0.9] + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" + self.assertIsInstance(past_key_values, Qwen3_5DynamicCache) + + # (batch, kv heads, seq_length, head_dim) + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + expected_shape = (batch_size, num_heads, seq_length, head_dim) + + attention_layer_indices = past_key_values.transformer_layers + self.assertListEqual( + [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + self.assertListEqual( + [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + + def _check_caches_are_equal(self, cache1, cache2): + "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" + if not len(cache1) == len(cache2): + raise ValueError("Both caches do not have the same number of layers.") + + num_layers = len(cache1) + for idx in range(num_layers): + if cache1.key_cache[idx] is not None: + torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) + torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + + def test_attention_outputs(self): + "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + + @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip("Intentionally not reversable (no changes) as only load time within a VLM depends on this") + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + pass + + +class Qwen3_5VisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=16, + text_config={ + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "head_dim": 8, + "hidden_size": 32, + "vocab_size": 99, + "intermediate_size": 37, + "max_position_embeddings": 512, + "model_type": "qwen3_vl", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "layer_types": ["full_attention", "linear_attention"], + "num_key_value_heads": 2, + "rope_theta": 10000, + "tie_word_embeddings": True, + "rope_parameters": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True}, + "linear_conv_kernel_dim": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_num_key_heads": 4, + "linear_num_value_heads": 8, + }, + vision_config={ + "depth": 2, + "in_chans": 3, + "hidden_act": "gelu_pytorch_tanh", + "intermediate_size": 32, + "out_hidden_size": 32, + "hidden_size": 32, + "num_heads": 4, + "patch_size": 16, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + "num_position_embeddings": 16, + }, + image_token_id=3, + video_token_id=4, + vision_start_token_id=5, + vision_end_token_id=6, + tie_word_embeddings=True, + is_training=True, + ): + self.parent = parent + self.ignore_index = ignore_index + self.is_training = is_training + + self.vision_config = vision_config + self.text_config = text_config + + self.vocab_size = text_config["vocab_size"] + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.head_dim = text_config["head_dim"] + self.hidden_size = text_config["hidden_size"] + self.intermediate_size = text_config["intermediate_size"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + self.num_key_value_heads = text_config["num_key_value_heads"] + self.rope_theta = text_config["rope_theta"] + self.rope_parameters = text_config["rope_parameters"] + self.hidden_act = text_config["hidden_act"] + self.max_position_embeddings = text_config["max_position_embeddings"] + self.model_type = text_config["model_type"] + + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.tie_word_embeddings = tie_word_embeddings + + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.num_image_tokens = 32 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return Qwen3_5Config( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + vision_end_token_id=self.vision_end_token_id, + tie_word_embeddings=self.tie_word_embeddings, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + temporal_patch_size = config.vision_config.temporal_patch_size + pixel_values = floats_tensor( + [ + self.batch_size * (self.image_size**2) // (patch_size**2), + self.num_channels * (patch_size**2) * temporal_patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.video_token_id] = self.pad_token_id + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id + input_ids[:, self.num_image_tokens] = self.image_token_id + input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class Qwen3_5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `Qwen3_5ForConditionalGeneration`. + """ + + all_model_classes = ( + ( + Qwen3_5Model, + Qwen3_5ForConditionalGeneration, + ) + if is_torch_available() + else () + ) + model_split_percents = [0.5, 0.8, 0.9] + + def setUp(self): + self.model_tester = Qwen3_5VisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen3_5Config, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" + self.assertIsInstance(past_key_values, Qwen3_5DynamicCache) + + # (batch, kv heads, seq_length, head_dim) + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + expected_shape = (batch_size, num_heads, seq_length, head_dim) + + attention_layer_indices = past_key_values.transformer_layers + self.assertListEqual( + [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + self.assertListEqual( + [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + + def _check_caches_are_equal(self, cache1, cache2): + "Qwen3.5 has a special Cache as it alternates with gated deltanet layers" + if not len(cache1) == len(cache2): + raise ValueError("Both caches do not have the same number of layers.") + + num_layers = len(cache1) + for idx in range(num_layers): + if cache1.key_cache[idx] is not None: + torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) + torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + + def test_attention_outputs(self): + "Needs to be overwritten as Qwen3.5 alternates between attention layers and gated deltanet layers." + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual( + len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types) + ) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.text_config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual( + len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types) + ) + self.assertListEqual( + list(attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len] + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual( + len(self_attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types) + ) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len] + ) + + @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.") + def test_multi_gpu_data_parallel_forward(self): + pass diff --git a/tests/models/qwen3_5_moe/__init__.py b/tests/models/qwen3_5_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py new file mode 100644 index 000000000000..edd88860b596 --- /dev/null +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -0,0 +1,491 @@ +# Copyright 2026 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Qwen3.5 model.""" + +import copy +import os +import re +import tempfile +import unittest + +from safetensors.torch import load_file + +from transformers import is_torch_available +from transformers.conversion_mapping import get_model_conversion_mapping +from transformers.core_model_loading import WeightRenaming, process_target_pattern +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + compare_state_dicts, + floats_tensor, + ids_tensor, +) + + +if is_torch_available(): + import torch + + from transformers import ( + Qwen3_5MoeConfig, + Qwen3_5MoeForCausalLM, + Qwen3_5MoeForConditionalGeneration, + Qwen3_5MoeModel, + Qwen3_5MoeTextConfig, + Qwen3_5MoeTextModel, + ) + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeDynamicCache + + +class Qwen3_5MoeTextModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = Qwen3_5MoeTextModel + causal_lm_class = Qwen3_5MoeForCausalLM + + def __init__(self, parent): + super().__init__(parent=parent) + self.layer_types = ["full_attention", "linear_attention"] + self.linear_conv_kernel_dim = 2 + self.linear_key_head_dim = 16 + self.linear_value_head_dim = 16 + self.linear_num_key_heads = 4 + self.linear_num_value_heads = 8 + + +@require_torch +class Qwen3_5MoeTextModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = Qwen3_5MoeTextModelTester + config_class = Qwen3_5MoeTextConfig + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" + self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache) + + # (batch, kv heads, seq_length, head_dim) + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + expected_shape = (batch_size, num_heads, seq_length, head_dim) + + attention_layer_indices = past_key_values.transformer_layers + self.assertListEqual( + [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + self.assertListEqual( + [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + + def _check_caches_are_equal(self, cache1, cache2): + "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" + if not len(cache1) == len(cache2): + raise ValueError("Both caches do not have the same number of layers.") + + num_layers = len(cache1) + for idx in range(num_layers): + if cache1.key_cache[idx] is not None: + torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) + torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + + def test_attention_outputs(self): + "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) + self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + + def test_reverse_loading_mapping(self, check_keys_were_modified=True): + """ + Overwritten to check for the moe portion but ignore the prefix as it results into a noop + (except we have a VLM struct initially) + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at + # lest one MoE layer here to check the mapping + config_to_set = config.get_text_config(decoder=True) + config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE + config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5... + config_to_set.mlp_only_layers = [0] # same but for qwens + config_to_set.num_dense_layers = 1 # lfm2_moe + + for model_class in self.all_model_classes: + # Each individual model is a subtest + with self.subTest(model_class.__name__): + model = model_class(copy.deepcopy(config)) + # Skip if no conversions + conversions = get_model_conversion_mapping(model, add_legacy=False) + if len(conversions) == 0: + self.skipTest("No conversion found for this model") + + # Find the model keys, so the targets according to the conversions + model_keys = list(model.state_dict().keys()) + + with tempfile.TemporaryDirectory() as tmpdirname: + # Serialize with reverse mapping + model.save_pretrained(tmpdirname) + state_dict = load_file(os.path.join(tmpdirname, "model.safetensors")) + # Get all the serialized keys that we just saved according to the reverse mapping + serialized_keys = list(state_dict.keys()) + + if check_keys_were_modified: + # They should be different, otherwise we did not perform any mapping + self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") + + # Check that for each conversion entry, we at least map to one key + for conversion in conversions: + for source_pattern in conversion.source_patterns: + # Sometimes the mappings specify keys that are tied, so absent from the saved state dict + if isinstance(conversion, WeightRenaming): + # We need to revert the target pattern to make it compatible with regex search + target_pattern_reversed = conversion.target_patterns[0] + captured_group = process_target_pattern(source_pattern)[1] + if captured_group: + target_pattern_reversed = target_pattern_reversed.replace(r"\1", captured_group) + if any(re.search(target_pattern_reversed, k) for k in model.all_tied_weights_keys.keys()): + continue + num_matches = sum(re.search(source_pattern, key) is not None for key in serialized_keys) + + # Key change: special case to load causal lm within vlm + if source_pattern == "^model.language_model": + continue + + self.assertTrue( + num_matches > 0, + f"`{source_pattern}` in `{conversion}` did not match any of the source keys. " + "This indicates whether that the pattern is not properly written, ot that it could not be reversed correctly", + ) + + # If everything is still good at this point, let's test that we perform the same operations both when + # reverting ops from `from_pretrained` and from `__init__` + with tempfile.TemporaryDirectory() as tmpdirname: + # The model was instantiated from __init__ before being saved + model.save_pretrained(tmpdirname) + state_dict_saved_from_init = load_file(os.path.join(tmpdirname, "model.safetensors")) + + # Now reload it + model_reloaded = model_class.from_pretrained(tmpdirname) + + # Make sure both loaded state_dict are identical + self.assertTrue(compare_state_dicts(model_reloaded.state_dict(), model.state_dict())) + + # The model was instantiated from `from_pretrained` before being saved + model_reloaded.save_pretrained(tmpdirname) + state_dict_saved_from_pretrained = load_file(os.path.join(tmpdirname, "model.safetensors")) + + # Make sure both saved state_dict are identical + self.assertTrue(compare_state_dicts(state_dict_saved_from_init, state_dict_saved_from_pretrained)) + + @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.") + def test_multi_gpu_data_parallel_forward(self): + pass + + +class Qwen3_5MoeVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=16, + text_config={ + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "head_dim": 8, + "hidden_size": 32, + "vocab_size": 99, + "intermediate_size": 37, + "max_position_embeddings": 512, + "model_type": "qwen3_vl", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "layer_types": ["full_attention", "linear_attention"], + "num_key_value_heads": 2, + "rope_theta": 10000, + "tie_word_embeddings": True, + "rope_parameters": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True}, + "linear_conv_kernel_dim": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_num_key_heads": 4, + "linear_num_value_heads": 8, + "moe_intermediate_size": 16, + "shared_expert_intermediate_size": 36, + "num_experts_per_tok": 2, + "num_experts": 8, + }, + vision_config={ + "depth": 2, + "in_chans": 3, + "hidden_act": "gelu_pytorch_tanh", + "intermediate_size": 32, + "out_hidden_size": 32, + "hidden_size": 32, + "num_heads": 4, + "patch_size": 16, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + "num_position_embeddings": 16, + }, + image_token_id=3, + video_token_id=4, + vision_start_token_id=5, + vision_end_token_id=6, + tie_word_embeddings=True, + is_training=True, + ): + self.parent = parent + self.ignore_index = ignore_index + self.is_training = is_training + + self.vision_config = vision_config + self.text_config = text_config + + self.vocab_size = text_config["vocab_size"] + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.head_dim = text_config["head_dim"] + self.hidden_size = text_config["hidden_size"] + self.intermediate_size = text_config["intermediate_size"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + self.num_key_value_heads = text_config["num_key_value_heads"] + self.rope_theta = text_config["rope_theta"] + self.rope_parameters = text_config["rope_parameters"] + self.hidden_act = text_config["hidden_act"] + self.max_position_embeddings = text_config["max_position_embeddings"] + self.model_type = text_config["model_type"] + + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.tie_word_embeddings = tie_word_embeddings + + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.num_image_tokens = 32 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return Qwen3_5MoeConfig( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + vision_end_token_id=self.vision_end_token_id, + tie_word_embeddings=self.tie_word_embeddings, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + temporal_patch_size = config.vision_config.temporal_patch_size + pixel_values = floats_tensor( + [ + self.batch_size * (self.image_size**2) // (patch_size**2), + self.num_channels * (patch_size**2) * temporal_patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.video_token_id] = self.pad_token_id + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id + input_ids[:, self.num_image_tokens] = self.image_token_id + input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class Qwen3_5MoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `Qwen3_5MoeForConditionalGeneration`. + """ + + all_model_classes = ( + ( + Qwen3_5MoeModel, + Qwen3_5MoeForConditionalGeneration, + ) + if is_torch_available() + else () + ) + + def setUp(self): + self.model_tester = Qwen3_5MoeVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Qwen3_5MoeConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" + self.assertIsInstance(past_key_values, Qwen3_5MoeDynamicCache) + + # (batch, kv heads, seq_length, head_dim) + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + expected_shape = (batch_size, num_heads, seq_length, head_dim) + + attention_layer_indices = past_key_values.transformer_layers + self.assertListEqual( + [past_key_values.key_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + self.assertListEqual( + [past_key_values.value_cache[idx].shape for idx in attention_layer_indices], + [expected_shape] * len(attention_layer_indices), + ) + + def _check_caches_are_equal(self, cache1, cache2): + "Qwen3.5 Moe has a special Cache as it alternates with gated deltanet layers" + if not len(cache1) == len(cache2): + raise ValueError("Both caches do not have the same number of layers.") + + num_layers = len(cache1) + for idx in range(num_layers): + if cache1.key_cache[idx] is not None: + torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) + torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + + def test_attention_outputs(self): + "Needs to be overwritten as Qwen3.5 Moe alternates between attention layers and gated deltanet layers." + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual( + len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types) + ) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.text_config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual( + len(attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types) + ) + self.assertListEqual( + list(attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len] + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual( + len(self_attentions), sum(layer == "full_attention" for layer in config.text_config.layer_types) + ) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), [config.text_config.num_attention_heads, seq_len, seq_len] + ) + + @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.") + def test_multi_gpu_data_parallel_forward(self): + pass diff --git a/utils/check_repo.py b/utils/check_repo.py index fb5e54097bc4..598ae9e90a36 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -72,6 +72,8 @@ "Qwen2_5_VisionTransformerPretrainedModel", "Qwen3VLVisionModel", "Qwen3VLMoeVisionModel", + "Qwen3_5VisionModel", + "Qwen3_5MoeVisionModel", "SwitchTransformersStack", "SiglipTextTransformer", "Siglip2TextTransformer", @@ -176,6 +178,8 @@ "Qwen3VLMoeModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3VLMoeForConditionalGeneration. "Qwen3VLTextModel", # Building part of bigger (tested) model. "Qwen3VLMoeTextModel", # Building part of bigger (tested) model. + "Qwen3_5TextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5ForConditionalGeneration. + "Qwen3_5MoeTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5MoeForConditionalGeneration. "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py index 9d6c38cda388..ed1bc84a376d 100644 --- a/utils/models_to_deprecate.py +++ b/utils/models_to_deprecate.py @@ -114,6 +114,8 @@ "qwen2_vl": ["qwen2_vl_text"], "qwen3_vl_moe": ["qwen3_vl_moe_text"], "qwen3_vl": ["qwen3_vl_text"], + "qwen3_5": ["qwen3_5text"], + "qwen3_5_moe": ["qwen3_5_moe_text"], "rt_detr": ["rt_detr_resnet"], "sam2": ["sam2_hiera_det_model", "sam2_vision_model"], "sam": ["sam_hq_vision_model", "sam_vision_model"],