From fc386207d6c0b9157ce1cd750a9faf72a9176532 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Mon, 9 Feb 2026 18:36:21 +0800 Subject: [PATCH 1/2] support qwen3.5 series --- mlx_vlm/models/base.py | 3 +- mlx_vlm/models/qwen3_5/__init__.py | 2 + mlx_vlm/models/qwen3_5/config.py | 84 +++ mlx_vlm/models/qwen3_5/language.py | 653 ++++++++++++++++++++++ mlx_vlm/models/qwen3_5/qwen3_5.py | 131 +++++ mlx_vlm/models/qwen3_5/vision.py | 4 + mlx_vlm/models/qwen3_5_moe/__init__.py | 2 + mlx_vlm/models/qwen3_5_moe/config.py | 86 +++ mlx_vlm/models/qwen3_5_moe/language.py | 112 ++++ mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py | 62 ++ mlx_vlm/models/qwen3_5_moe/vision.py | 4 + mlx_vlm/models/qwen3_vl/vision.py | 2 +- mlx_vlm/prompt_utils.py | 4 + 13 files changed, 1147 insertions(+), 2 deletions(-) create mode 100644 mlx_vlm/models/qwen3_5/__init__.py create mode 100644 mlx_vlm/models/qwen3_5/config.py create mode 100644 mlx_vlm/models/qwen3_5/language.py create mode 100644 mlx_vlm/models/qwen3_5/qwen3_5.py create mode 100644 mlx_vlm/models/qwen3_5/vision.py create mode 100644 mlx_vlm/models/qwen3_5_moe/__init__.py create mode 100644 mlx_vlm/models/qwen3_5_moe/config.py create mode 100644 mlx_vlm/models/qwen3_5_moe/language.py create mode 100644 mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py create mode 100644 mlx_vlm/models/qwen3_5_moe/vision.py diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index 62c91736c..b60b3a8df 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -6,7 +6,8 @@ import mlx.core as mx import mlx.nn as nn -from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention +from mlx_lm.models.base import create_attention_mask, create_ssm_mask, scaled_dot_product_attention + from PIL import Image diff --git a/mlx_vlm/models/qwen3_5/__init__.py b/mlx_vlm/models/qwen3_5/__init__.py new file mode 100644 index 000000000..2cf179823 --- /dev/null +++ b/mlx_vlm/models/qwen3_5/__init__.py @@ -0,0 +1,2 @@ +from 
.config import ModelConfig, TextConfig, VisionConfig +from .qwen3_5 import LanguageModel, Model, VisionModel diff --git a/mlx_vlm/models/qwen3_5/config.py b/mlx_vlm/models/qwen3_5/config.py new file mode 100644 index 000000000..7094e7d96 --- /dev/null +++ b/mlx_vlm/models/qwen3_5/config.py @@ -0,0 +1,84 @@ +import inspect +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union + +from ..base import BaseModelConfig +from ..qwen3_vl.config import VisionConfig as Qwen3VLVisionConfig + + +@dataclass +class VisionConfig(Qwen3VLVisionConfig): + model_type: str = "qwen3_5" + + def __post_init__(self): + if self.deepstack_visual_indexes is not None and len(self.deepstack_visual_indexes) > 0: + raise ValueError(f"deepstack is disabled for qwen3.5 temporally, but it is set to {self.deepstack_visual_indexes}") + self.deepstack_visual_indexes = [] + + +@dataclass +class TextConfig(BaseModelConfig): + model_type: str + hidden_size: int + intermediate_size: int + linear_num_value_heads: int + linear_num_key_heads: int + linear_key_head_dim: int + linear_value_head_dim: int + linear_conv_kernel_dim: int + num_hidden_layers: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int + max_position_embeddings: int + tie_word_embeddings: bool = False + attention_bias: bool = False + head_dim: Optional[int] = None + rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( + default_factory=lambda: {"type": "default", "mrope_section": [11, 11, 10], "rope_theta": 100000, "partial_rotary_factor": 0.25} + ) + full_attention_interval: int = 4 + + def __post_init__(self): + if self.rope_parameters: + # Normalize rope_parameters keys (accept both 'rope_type' and 'type') + if "type" not in self.rope_parameters and "rope_type" in self.rope_parameters: + self.rope_parameters["type"] = self.rope_parameters.pop("rope_type") + + required_keys = {"mrope_section", "type", "rope_theta", 
"partial_rotary_factor"} + if not all(key in self.rope_parameters for key in required_keys): + raise ValueError(f"rope_parameters must contain keys {required_keys}") + + +@dataclass +class ModelConfig(BaseModelConfig): + text_config: TextConfig + vision_config: VisionConfig + model_type: str + ignore_index: int = -100 + image_token_id: int = 248056 + video_token_id: int = 248057 + image_token_index: Optional[int] = None + video_token_index: Optional[int] = None + vision_start_token_id: int = 248045 + vision_end_token_id: int = 248046 + vocab_size: int = 248320 + eos_token_id: Optional[List[int]] = None + + def __post_init__(self): + if self.image_token_index is None: + self.image_token_index = self.image_token_id + if self.video_token_index is None: + self.video_token_index = self.video_token_id + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py new file mode 100644 index 000000000..384490610 --- /dev/null +++ b/mlx_vlm/models/qwen3_5/language.py @@ -0,0 +1,653 @@ +from typing import Optional, Any + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from ..base import ( + LanguageModelOutput, + create_attention_mask, + create_ssm_mask, + scaled_dot_product_attention, +) +from ..cache import ArraysCache, KVCache +from mlx_lm.models.gated_delta import gated_delta_update +from mlx_lm.models.activations import swiglu + +from .config import ModelConfig, TextConfig + + +class Qwen3_5RotaryEmbedding: + def __init__( + self, dim, max_position_embeddings=2048, base=10000, mrope_section=[11, 11, 0] + ): + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / ( + self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim) + ) + self.inv_freq = inv_freq + + self.mrope_section = mrope_section + + def 
apply_interleaved_mrope(self, freqs, mrope_section): + freqs_t = freqs[0] + for dim, offset in enumerate((1, 2), start=1): + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + def __call__(self, x, position_ids): + if position_ids.ndim == 2: + position_ids = mx.broadcast_to( + position_ids[None, ...], + (3, position_ids.shape[0], position_ids.shape[1]), + ) + + inv_freq_expanded = mx.broadcast_to( + self.inv_freq[None, None, :, None].astype(mx.float32), + (3, position_ids.shape[1], self.inv_freq.shape[0], 1), + ) + position_ids_expanded = position_ids[:, :, None, :].astype(mx.float32) + + freqs = inv_freq_expanded @ position_ids_expanded + freqs = mx.swapaxes(freqs, 2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb) + sin = mx.sin(emb) + + return cos.astype(x.dtype), sin.astype(x.dtype) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mx.concatenate([-x2, x1], axis=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unqueeze_dim=1): + cos = mx.expand_dims(cos, axis=unqueeze_dim) + sin = mx.expand_dims(sin, axis=unqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot = q[..., :rotary_dim] + q_pass = q[..., rotary_dim:] + + k_rot = k[..., :rotary_dim] + k_pass = k[..., rotary_dim:] + + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + q_embed = mx.concatenate([q_embed, q_pass], axis=-1) + k_embed = mx.concatenate([k_embed, k_pass], axis=-1) + + return q_embed, k_embed + + +class Qwen3_5RMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = mx.ones(hidden_size) + + def __call__( + self, hidden_states: mx.array, gate: mx.array | None = None + ) -> mx.array: + x = mx.fast.rms_norm(hidden_states, 
self.weight, self.eps) + if gate is not None: + x = swiglu(gate, x) + return x + + +class Qwen3_5Attention(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.num_key_value_heads = args.num_key_value_heads + self.num_attention_heads = args.num_attention_heads + self.head_dim = args.head_dim + self.scale = self.head_dim**-0.5 + + self.q_proj = nn.Linear( + args.hidden_size, + self.num_attention_heads * self.head_dim * 2, + bias=args.attention_bias, + ) + self.k_proj = nn.Linear( + args.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=args.attention_bias, + ) + self.v_proj = nn.Linear( + args.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=args.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, + args.hidden_size, + bias=args.attention_bias, + ) + + self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + + self.rotary_emb = Qwen3_5RotaryEmbedding( + int(self.head_dim * args.rope_parameters["partial_rotary_factor"]), + max_position_embeddings=args.max_position_embeddings, + base=args.rope_parameters["rope_theta"], + mrope_section=args.rope_parameters["mrope_section"], + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + position_ids: Optional[mx.array] = None, + ) -> mx.array: + B, L, D = x.shape + + q_proj_output = self.q_proj(x) + queries, gate = mx.split( + q_proj_output.reshape(B, L, self.num_attention_heads, -1), 2, axis=-1 + ) + gate = gate.reshape(B, L, -1) + + keys, values = self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries).transpose(0, 2, 1, 3) + keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( + 0, 2, 1, 3 + ) + + kv_seq_len = keys.shape[-2] + + if position_ids is None: + kv_seq_len += 
cache.offset + 1 + position_ids = mx.arange(cache.offset, cache.offset + L) + position_ids = mx.expand_dims(position_ids, axis=0) + position_ids = mx.tile(position_ids, (3, 1, 1)) + else: + kv_seq_len += cache.offset + 1 if cache is not None else 0 + + cos, sin = self.rotary_emb(values, position_ids) + + if mask is not None and isinstance(mask, mx.array): + mask = mask[..., :kv_seq_len] + + queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output * mx.sigmoid(gate)) + + +class Qwen3_5MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x))) + + +class Qwen3_5GatedDeltaNet(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + if self.num_v_heads % self.num_k_heads != 0: + raise ValueError( + f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})" + ) + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_norm_epsilon = config.rms_norm_eps + + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, 
+ bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=0, + ) + + self.in_proj_qkv = nn.Linear( + self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + self.dt_bias = mx.ones(self.num_v_heads) + + A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,)) + self.A_log = mx.log(A) + + self.norm = Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, S, _ = inputs.shape + + mixed_qkv = self.in_proj_qkv(inputs) + + z = self.in_proj_z(inputs) + z = z.reshape(B, S, -1, self.head_v_dim) + + b = self.in_proj_b(inputs) + a = self.in_proj_a(inputs) + + if cache is not None and cache[0] is not None: + conv_state = cache[0] + else: + conv_state = mx.zeros( + (B, self.conv_kernel_size - 1, self.conv_dim), + dtype=inputs.dtype, + ) + + if mask is not None: + mixed_qkv = mx.where(mask[..., None], mixed_qkv, 0) + conv_input = mx.concatenate([conv_state, mixed_qkv], axis=1) + if cache is not None: + cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :] + conv_out = nn.silu(self.conv1d(conv_input)) + + q, k, v = [ + t.reshape(B, S, h, d) + for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim], + ) + ] + + state = cache[1] if cache else None + inv_scale = k.shape[-1] ** -0.5 + q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) + k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) + + out, state = gated_delta_update( + q, + k, + v, + a, + b, + 
self.A_log, + self.dt_bias, + state, + mask, + use_kernel=not self.training, + ) + + if cache is not None: + cache[1] = state + + out = self.norm(out, z) + return self.out_proj(out.reshape(B, S, -1)) + + +class Qwen3_5DecoderLayer(nn.Module): + def __init__(self, args: TextConfig, layer_idx: int): + super().__init__() + self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0 + if self.is_linear: + self.linear_attn = Qwen3_5GatedDeltaNet(args) + else: + self.self_attn = Qwen3_5Attention(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.mlp = Qwen3_5MLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + position_ids: Optional[mx.array] = None, + ) -> mx.array: + if self.is_linear: + r = self.linear_attn(self.input_layernorm(x), mask, cache) + else: + r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids) + h = x + r + out = h + self.mlp(self.post_attention_layernorm(h)) + return out + + +class Qwen3_5Model(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.args = args + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + Qwen3_5DecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.ssm_idx = 0 + self.fa_idx = args.full_attention_interval - 1 + + def __call__( + self, + inputs: mx.array, + inputs_embeds: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + cache=None, + position_ids: Optional[mx.array] = None, + ): + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds + + if cache is None: + cache = [None] * len(self.layers) + + fa_mask = create_attention_mask(h, cache[self.fa_idx]) + ssm_mask = 
create_ssm_mask(h, cache[self.ssm_idx]) + + for layer, c in zip(self.layers, cache): + mask = ssm_mask if layer.is_linear else fa_mask + h = layer(h, mask, c, position_ids) + + return self.norm(h) + + +class LanguageModel(nn.Module): + def __init__(self, args: TextConfig, config: ModelConfig = None): + super().__init__() + self.args = args + self.config = config + self.model_type = args.model_type + self.model = Qwen3_5Model(args) + self._rope_deltas = None + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def get_rope_index( + self, + input_ids: mx.array, + image_grid_thw: Optional[mx.array] = None, + video_grid_thw: Optional[mx.array] = None, + attention_mask: Optional[mx.array] = None, + ): + batch_size, seq_length = input_ids.shape + position_ids = mx.arange(seq_length, dtype=mx.int32) + position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length)) + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = mx.ones_like(input_ids) + position_ids = mx.ones( + (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + input_ids = mx.where( + attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids) + ) + image_nums, video_nums = 0, 0 + vision_start_indices = mx.sum( + mx.where( + input_ids == vision_start_token_id, + mx.arange(input_ids.shape[0]), + mx.zeros_like(input_ids), + ) + ) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum().item() + video_nums = (vision_tokens 
== video_token_id).sum().item() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + index = mx.arange(text_len).reshape(1, text_len) + index = mx.broadcast_to(index, (3, text_len)) + index = index + st_idx + llm_pos_ids_list.append(index) + t_index = mx.arange(llm_grid_t).reshape(llm_grid_t, 1) + t_index = mx.broadcast_to( + t_index, (llm_grid_t, llm_grid_h * llm_grid_w) + ) + t_index = t_index.flatten() + + h_index = mx.arange(llm_grid_h).reshape(1, llm_grid_h, 1) + h_index = mx.broadcast_to( + h_index, (llm_grid_t, llm_grid_h, llm_grid_w) + ) + h_index = h_index.flatten() + + w_index = mx.arange(llm_grid_w).reshape(1, 1, llm_grid_w) + w_index = mx.broadcast_to( + w_index, (llm_grid_t, llm_grid_h, llm_grid_w) + ) + w_index = w_index.flatten() + + llm_pos_ids_list.append( + mx.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + if st < len(input_tokens): + st_idx = ( + 
llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + + t_index = mx.arange(text_len).reshape(1, text_len) + t_index = mx.broadcast_to(t_index, (3, text_len)) + + llm_pos_ids_list.append(t_index + st_idx) + + llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) + mask = mx.array(attention_mask[i] == 1) + expanded_mask = mx.expand_dims(mask, axis=0) + expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0])) + expanded_positions = mx.expand_dims(llm_positions, axis=1) + new_positions = mx.where( + expanded_mask, expanded_positions, position_ids[:, i : i + 1, :] + ) + updated_position_ids = mx.concatenate( + [ + position_ids[:, :i, :], + new_positions, + position_ids[:, i + 1 :, :], + ], + axis=1, + ) + position_ids = updated_position_ids + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = mx.array(mrope_position_deltas)[0] + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1 + position_ids = mx.where( + attention_mask == 0, mx.ones_like(position_ids), position_ids + ) + position_ids = mx.expand_dims(position_ids[0], axis=0) + position_ids = mx.tile(position_ids, (3, 1, 1)) + max_position_ids = position_ids.max(0, keepdims=False)[0].max( + -1, keepdims=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1) + position_ids = mx.broadcast_to( + position_ids, (3, input_ids.shape[0], input_ids.shape[1]) + ) + mrope_position_deltas = mx.zeros( + [input_ids.shape[0], 1], + dtype=input_ids.dtype, + ) + return position_ids, mrope_position_deltas + + def __call__( + self, + inputs: mx.array, + inputs_embeds: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + cache=None, + **kwargs, + ): + position_ids = 
kwargs.pop("position_ids", None) + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + if pixel_values is not None: + self._rope_deltas = None + + cache_offset = 0 + if cache and cache[self.model.fa_idx] is not None: + offset = cache[self.model.fa_idx].offset + if isinstance(offset, int): + cache_offset = offset + elif isinstance(offset, mx.array): + cache_offset = (offset if offset.ndim == 0 else offset[0]).item() + else: + raise ValueError(f"Unexpected cache offset type: {type(offset)}") + + if position_ids is None and (mask is None or mask.ndim == 2): + if ( + ( + cache is not None + and cache[self.model.fa_idx] is not None + and (cache_offset == 0) + ) + or self._rope_deltas is None + or cache is None + ): + position_ids, rope_deltas = self.get_rope_index( + inputs, image_grid_thw, video_grid_thw, mask + ) + self._rope_deltas = rope_deltas + else: + batch_size, seq_length = inputs.shape + delta = mx.array( + cache_offset + self._rope_deltas if cache is not None else 0 + ) + position_ids = mx.arange(seq_length).reshape(1, -1) + position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length)) + + if cache_offset is not None: + if delta.ndim == 0: + delta = mx.expand_dims(delta, axis=0) + + if delta.shape[0] < batch_size: + delta = mx.tile(delta, (batch_size, 1)) + else: + delta = delta[:batch_size] + + position_ids = mx.add(position_ids, delta)[None, ...] 
+ position_ids = mx.broadcast_to( + position_ids, (3, batch_size, seq_length) + ) + + out = self.model( + inputs, + cache=cache, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + ) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return LanguageModelOutput(logits=out) + + @property + def layers(self): + return self.model.layers + + def make_cache(self): + return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers] + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/mlx_vlm/models/qwen3_5/qwen3_5.py b/mlx_vlm/models/qwen3_5/qwen3_5.py new file mode 100644 index 000000000..b52581511 --- /dev/null +++ b/mlx_vlm/models/qwen3_5/qwen3_5.py @@ -0,0 +1,131 @@ + +import mlx.core as mx +import mlx.nn as nn +from typing import Optional + +from ..base import InputEmbeddingsFeatures +from ..qwen3_vl import Model as Qwen3VLModel +from ..qwen3_vl.qwen3_vl import masked_scatter + +from .config import ModelConfig +from .language import LanguageModel +from .vision import VisionModel + +class Model(Qwen3VLModel): + + def __init__(self, config: ModelConfig): + # only initialize nn.Module, skip the initialization of vision_tower and language_model in the parent class + nn.Module.__init__(self) + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config, config) + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + **kwargs, + ): + image_grid_thw = kwargs.get("image_grid_thw", None) + video_grid_thw = kwargs.get("video_grid_thw", None) + mask = kwargs.get("mask", None) + grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw + + if pixel_values is None: + # Reset position state for 
text-only generation + self.language_model._position_ids = None + self.language_model._rope_deltas = None + return InputEmbeddingsFeatures( + inputs_embeds=self.language_model.model.embed_tokens(input_ids) + ) + + dtype = self.vision_tower.patch_embed.proj.weight.dtype + pixel_values = pixel_values.astype(dtype) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get the ouptut hidden states from the vision model + hidden_states, _ = self.vision_tower( + pixel_values, grid_thw + ) + + # Insert special image tokens in the input_ids + inputs_embeds, _ = self.merge_input_ids_with_image_features( + hidden_states, + inputs_embeds, + input_ids, + self.config.image_token_index, + self.config.video_token_index, + ) + + # Pre-calculate position_ids for chunked prefill + if image_grid_thw is not None or video_grid_thw is not None: + position_ids, rope_deltas = self.language_model.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, mask + ) + self.language_model._position_ids = position_ids + self.language_model._rope_deltas = rope_deltas + + return InputEmbeddingsFeatures( + inputs_embeds=inputs_embeds, + ) + + @staticmethod + def merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, image_token_index, video_token_index + ): + special_image_mask = input_ids == image_token_index + special_video_mask = input_ids == video_token_index + special_image_mask = special_image_mask | special_video_mask + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask[..., None] + special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape) + + n_image_features = image_features.shape[0] + n_image_mask_elements = special_image_mask.sum() + if n_image_mask_elements != image_features.size: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + inputs_embeds = 
masked_scatter( + inputs_embeds, special_image_mask, image_features + ) + + return inputs_embeds, special_image_mask + + def sanitize(self, weights): + # ignore mtp weights + weights = {key: value for key, value in weights.items() if "mtp." not in key} + + if self.config.text_config.tie_word_embeddings: + weights.pop("lm_head.weight", None) + + norm_keys = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + "model.norm.weight", + ".q_norm.weight", + ".k_norm.weight", + ) + + sanitized_weights = {} + for key, value in weights.items(): + if "model" in key: + if "model.language_model" in key: + key = key.replace("model.language_model", "language_model.model") + elif "model.visual" in key: + key = key.replace("model.visual", "vision_tower") + elif "lm_head" in key: + key = key.replace("lm_head", "language_model.lm_head") + + if "conv1d.weight" in key and value.shape[-1] != 1: + value = value.moveaxis(2, 1) + if any(key.endswith(sfx) for sfx in norm_keys): + if value.ndim == 1: + value += 1.0 + + sanitized_weights[key] = value + + return sanitized_weights diff --git a/mlx_vlm/models/qwen3_5/vision.py b/mlx_vlm/models/qwen3_5/vision.py new file mode 100644 index 000000000..9b02abaf9 --- /dev/null +++ b/mlx_vlm/models/qwen3_5/vision.py @@ -0,0 +1,4 @@ +from ..qwen3_vl import VisionModel as Qwen3VLVisionModel + +class VisionModel(Qwen3VLVisionModel): + pass \ No newline at end of file diff --git a/mlx_vlm/models/qwen3_5_moe/__init__.py b/mlx_vlm/models/qwen3_5_moe/__init__.py new file mode 100644 index 000000000..27d7793b1 --- /dev/null +++ b/mlx_vlm/models/qwen3_5_moe/__init__.py @@ -0,0 +1,2 @@ +from .config import ModelConfig, TextConfig, VisionConfig +from .qwen3_5_moe import LanguageModel, Model, VisionModel diff --git a/mlx_vlm/models/qwen3_5_moe/config.py b/mlx_vlm/models/qwen3_5_moe/config.py new file mode 100644 index 000000000..368f22f07 --- /dev/null +++ b/mlx_vlm/models/qwen3_5_moe/config.py @@ -0,0 +1,86 @@ +import inspect +from dataclasses 
import dataclass, field +from typing import Dict, List, Optional, Union + +from ..base import BaseModelConfig +from ..qwen3_vl.config import VisionConfig as Qwen3VLVisionConfig + + +@dataclass +class VisionConfig(Qwen3VLVisionConfig): + model_type: str = "qwen3_5_moe" + + def __post_init__(self): + if self.deepstack_visual_indexes is not None and len(self.deepstack_visual_indexes) > 0: + raise ValueError(f"deepstack is disabled for qwen3.5 temporally, but it is set to {self.deepstack_visual_indexes}") + self.deepstack_visual_indexes = [] + +@dataclass +class TextConfig(BaseModelConfig): + model_type: str + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + linear_num_value_heads: int + linear_num_key_heads: int + linear_key_head_dim: int + linear_value_head_dim: int + linear_conv_kernel_dim: int + num_experts: int + num_experts_per_tok: int + shared_expert_intermediate_size: int + moe_intermediate_size: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int + max_position_embeddings: int + tie_word_embeddings: bool = False + attention_bias: bool = False + head_dim: Optional[int] = None + rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( + default_factory=lambda: {"type": "default", "mrope_section": [11, 11, 10], "rope_theta": 100000, "partial_rotary_factor": 0.25} + ) + full_attention_interval: int = 4 + + def __post_init__(self): + if self.rope_parameters: + # Normalize rope_parameters keys (accept both 'rope_type' and 'type') + if "type" not in self.rope_parameters and "rope_type" in self.rope_parameters: + self.rope_parameters["type"] = self.rope_parameters.pop("rope_type") + + required_keys = {"mrope_section", "type", "rope_theta", "partial_rotary_factor"} + if not all(key in self.rope_parameters for key in required_keys): + raise ValueError(f"rope_parameters must contain keys {required_keys}") + + +@dataclass +class ModelConfig(BaseModelConfig): + text_config: TextConfig + vision_config: 
VisionConfig + model_type: str + ignore_index: int = -100 + image_token_id: int = 248056 + video_token_id: int = 248057 + image_token_index: Optional[int] = None + video_token_index: Optional[int] = None + vision_start_token_id: int = 248045 + vision_end_token_id: int = 248046 + vocab_size: int = 248320 + eos_token_id: Optional[List[int]] = None + + def __post_init__(self): + if self.image_token_index is None: + self.image_token_index = self.image_token_id + if self.video_token_index is None: + self.video_token_index = self.video_token_id + + @classmethod + def from_dict(cls, params): + + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) diff --git a/mlx_vlm/models/qwen3_5_moe/language.py b/mlx_vlm/models/qwen3_5_moe/language.py new file mode 100644 index 000000000..40ea49d5d --- /dev/null +++ b/mlx_vlm/models/qwen3_5_moe/language.py @@ -0,0 +1,112 @@ +from typing import Optional, Any + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.switch_layers import SwitchGLU + +from ..qwen3_5.language import ( + LanguageModel as Qwen3_5LanguageModel, + Qwen3_5Attention as Qwen3_5MoeAttention, + Qwen3_5GatedDeltaNet as Qwen3_5MoeGatedDeltaNet, + Qwen3_5MLP as Qwen3_5MoeMLP, + Qwen3_5Model, +) +from .config import ModelConfig, TextConfig + + +class Qwen3_5MoeSparseMoeBlock(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + dim = args.hidden_size + intermediate_size = args.moe_intermediate_size + shared_expert_intermediate_size = args.shared_expert_intermediate_size + + self.num_experts = num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + + self.gate = nn.Linear(dim, num_experts, bias=False) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + self.shared_expert = Qwen3_5MoeMLP(dim, shared_expert_intermediate_size) + self.shared_expert_gate = nn.Linear(dim, 1, bias=False) + + def __call__( + self, + x: mx.array, + ) -> mx.array: + gates = 
self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:] + scores = mx.take_along_axis(gates, inds, axis=-1) + scores = scores / scores.sum(axis=-1, keepdims=True) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + + shared_y = self.shared_expert(x) + shared_y = mx.sigmoid(self.shared_expert_gate(x)) * shared_y + + return y + shared_y + + +class Qwen3_5MoeDecoderLayer(nn.Module): + def __init__(self, args: TextConfig, layer_idx: int): + super().__init__() + self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0 + if self.is_linear: + self.linear_attn = Qwen3_5MoeGatedDeltaNet(args) + else: + self.self_attn = Qwen3_5MoeAttention(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.mlp = Qwen3_5MoeSparseMoeBlock(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + position_ids: Optional[mx.array] = None, + ) -> mx.array: + if self.is_linear: + r = self.linear_attn(self.input_layernorm(x), mask, cache) + else: + r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids) + h = x + r + out = h + self.mlp(self.post_attention_layernorm(h)) + return out + + +class Qwen3_5MoeModel(Qwen3_5Model): + + def __init__(self, args: TextConfig): + nn.Module.__init__(self) + self.args = args + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + Qwen3_5MoeDecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.ssm_idx = 0 + self.fa_idx = args.full_attention_interval - 1 + + +class LanguageModel(Qwen3_5LanguageModel): + + def __init__(self, args: TextConfig, config: ModelConfig = None): + nn.Module.__init__(self) + self.args = 
args + self.config = config + self.model_type = args.model_type + self.model = Qwen3_5MoeModel(args) + self._rope_deltas = None + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py b/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py new file mode 100644 index 000000000..374f8f949 --- /dev/null +++ b/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py @@ -0,0 +1,62 @@ +import mlx.core as mx +import mlx.nn as nn + +from ..qwen3_5 import Model as Qwen3_5Model + +from .config import ModelConfig +from .language import LanguageModel +from .vision import VisionModel + +class Model(Qwen3_5Model): + + def __init__(self, config: ModelConfig): + # only initialize nn.Module (unbound call, so self must be passed explicitly); skip the initialization of vision_tower and language_model in the parent class + nn.Module.__init__(self) + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config, config) + + def sanitize(self, weights): + # ignore mtp weights + weights = {key: value for key, value in weights.items() if "mtp." 
not in key} + + if self.config.text_config.tie_word_embeddings: + weights.pop("lm_head.weight", None) + + for l in range(self.config.text_config.num_hidden_layers): + prefix = f"model.language_model.layers.{l}.mlp" + # process gate_up_proj [num_experts, 2 * intermediate_size, hidden_size] + gate_up_weight = weights.pop(f"{prefix}.experts.gate_up_proj") + gate_weight, up_weights = mx.split(gate_up_weight, 2, axis=-2) + weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_weight + weights[f"{prefix}.switch_mlp.up_proj.weight"] = up_weights + # down_proj + weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop(f"{prefix}.experts.down_proj") + + norm_keys = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + "model.norm.weight", + ".q_norm.weight", + ".k_norm.weight", + ) + + sanitized_weights = {} + for key, value in weights.items(): + if "model" in key: + if "model.language_model" in key: + key = key.replace("model.language_model", "language_model.model") + elif "model.visual" in key: + key = key.replace("model.visual", "vision_tower") + elif "lm_head" in key: + key = key.replace("lm_head", "language_model.lm_head") + + if "conv1d.weight" in key and value.shape[-1] != 1: + value = value.moveaxis(2, 1) + if any(key.endswith(sfx) for sfx in norm_keys): + if value.ndim == 1: + value += 1.0 + + sanitized_weights[key] = value + + return sanitized_weights diff --git a/mlx_vlm/models/qwen3_5_moe/vision.py b/mlx_vlm/models/qwen3_5_moe/vision.py new file mode 100644 index 000000000..9b02abaf9 --- /dev/null +++ b/mlx_vlm/models/qwen3_5_moe/vision.py @@ -0,0 +1,4 @@ +from ..qwen3_vl import VisionModel as Qwen3VLVisionModel + +class VisionModel(Qwen3VLVisionModel): + pass \ No newline at end of file diff --git a/mlx_vlm/models/qwen3_vl/vision.py b/mlx_vlm/models/qwen3_vl/vision.py index 2d13c02d5..4d5f63e28 100644 --- a/mlx_vlm/models/qwen3_vl/vision.py +++ b/mlx_vlm/models/qwen3_vl/vision.py @@ -199,7 +199,7 @@ def __init__(self, config: 
VisionConfig) -> None: self.config = config self.model_type = config.model_type - if self.model_type != "qwen3_vl": + if self.model_type not in ["qwen3_vl", "qwen3_5", "qwen3_5_moe"]: raise ValueError(f"Unsupported model type: {self.model_type}") self.spatial_merge_size = config.spatial_merge_size diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 0e1a8334a..46233a94a 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -41,6 +41,8 @@ class MessageFormat(Enum): "qwen2_5_vl": MessageFormat.LIST_WITH_IMAGE_FIRST, "qwen3_vl": MessageFormat.LIST_WITH_IMAGE_FIRST, "qwen3_vl_moe": MessageFormat.LIST_WITH_IMAGE_FIRST, + "qwen3_5": MessageFormat.LIST_WITH_IMAGE_FIRST, + "qwen3_5_moe": MessageFormat.LIST_WITH_IMAGE_FIRST, "mistral3": MessageFormat.LIST_WITH_IMAGE_FIRST, "glm4v": MessageFormat.LIST_WITH_IMAGE_FIRST, "glm4v_moe": MessageFormat.LIST_WITH_IMAGE_FIRST, @@ -198,6 +200,8 @@ def format_message( "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", + "qwen3_5", + "qwen3_5_moe", ] and kwargs.get("video"): return self._format_video_message(prompt, kwargs) From 93f958a816eedd87144824105588e1360ff00d0a Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Mon, 9 Feb 2026 19:43:56 +0800 Subject: [PATCH 2/2] pre-commit --- mlx_vlm/models/base.py | 7 ++++-- mlx_vlm/models/qwen3_5/config.py | 29 ++++++++++++++++++----- mlx_vlm/models/qwen3_5/language.py | 8 +++---- mlx_vlm/models/qwen3_5/qwen3_5.py | 8 +++---- mlx_vlm/models/qwen3_5/vision.py | 3 ++- mlx_vlm/models/qwen3_5_moe/config.py | 29 +++++++++++++++++++---- mlx_vlm/models/qwen3_5_moe/language.py | 15 +++++------- mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py | 6 +++-- mlx_vlm/models/qwen3_5_moe/vision.py | 3 ++- 9 files changed, 72 insertions(+), 36 deletions(-) diff --git a/mlx_vlm/models/base.py b/mlx_vlm/models/base.py index b60b3a8df..3b6077aa8 100644 --- a/mlx_vlm/models/base.py +++ b/mlx_vlm/models/base.py @@ -6,8 +6,11 @@ import mlx.core as mx import mlx.nn as nn -from 
mlx_lm.models.base import create_attention_mask, create_ssm_mask, scaled_dot_product_attention - +from mlx_lm.models.base import ( + create_attention_mask, + create_ssm_mask, + scaled_dot_product_attention, +) from PIL import Image diff --git a/mlx_vlm/models/qwen3_5/config.py b/mlx_vlm/models/qwen3_5/config.py index 7094e7d96..64daed826 100644 --- a/mlx_vlm/models/qwen3_5/config.py +++ b/mlx_vlm/models/qwen3_5/config.py @@ -11,8 +11,13 @@ class VisionConfig(Qwen3VLVisionConfig): model_type: str = "qwen3_5" def __post_init__(self): - if self.deepstack_visual_indexes is not None and len(self.deepstack_visual_indexes) > 0: - raise ValueError(f"deepstack is disabled for qwen3.5 temporally, but it is set to {self.deepstack_visual_indexes}") + if ( + self.deepstack_visual_indexes is not None + and len(self.deepstack_visual_indexes) > 0 + ): + raise ValueError( + f"deepstack is disabled for qwen3.5 temporally, but it is set to {self.deepstack_visual_indexes}" + ) self.deepstack_visual_indexes = [] @@ -36,17 +41,30 @@ class TextConfig(BaseModelConfig): attention_bias: bool = False head_dim: Optional[int] = None rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( - default_factory=lambda: {"type": "default", "mrope_section": [11, 11, 10], "rope_theta": 100000, "partial_rotary_factor": 0.25} + default_factory=lambda: { + "type": "default", + "mrope_section": [11, 11, 10], + "rope_theta": 100000, + "partial_rotary_factor": 0.25, + } ) full_attention_interval: int = 4 def __post_init__(self): if self.rope_parameters: # Normalize rope_parameters keys (accept both 'rope_type' and 'type') - if "type" not in self.rope_parameters and "rope_type" in self.rope_parameters: + if ( + "type" not in self.rope_parameters + and "rope_type" in self.rope_parameters + ): self.rope_parameters["type"] = self.rope_parameters.pop("rope_type") - required_keys = {"mrope_section", "type", "rope_theta", "partial_rotary_factor"} + required_keys = { + "mrope_section", + 
"type", + "rope_theta", + "partial_rotary_factor", + } if not all(key in self.rope_parameters for key in required_keys): raise ValueError(f"rope_parameters must contain keys {required_keys}") @@ -81,4 +99,3 @@ def from_dict(cls, params): if k in inspect.signature(cls).parameters } ) - diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 384490610..2eb2a1429 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1,8 +1,9 @@ -from typing import Optional, Any +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -import numpy as np +from mlx_lm.models.activations import swiglu +from mlx_lm.models.gated_delta import gated_delta_update from ..base import ( LanguageModelOutput, @@ -11,9 +12,6 @@ scaled_dot_product_attention, ) from ..cache import ArraysCache, KVCache -from mlx_lm.models.gated_delta import gated_delta_update -from mlx_lm.models.activations import swiglu - from .config import ModelConfig, TextConfig diff --git a/mlx_vlm/models/qwen3_5/qwen3_5.py b/mlx_vlm/models/qwen3_5/qwen3_5.py index b52581511..8ba6e5b96 100644 --- a/mlx_vlm/models/qwen3_5/qwen3_5.py +++ b/mlx_vlm/models/qwen3_5/qwen3_5.py @@ -1,16 +1,16 @@ +from typing import Optional import mlx.core as mx import mlx.nn as nn -from typing import Optional from ..base import InputEmbeddingsFeatures from ..qwen3_vl import Model as Qwen3VLModel from ..qwen3_vl.qwen3_vl import masked_scatter - from .config import ModelConfig from .language import LanguageModel from .vision import VisionModel + class Model(Qwen3VLModel): def __init__(self, config: ModelConfig): @@ -46,9 +46,7 @@ def get_input_embeddings( inputs_embeds = self.language_model.model.embed_tokens(input_ids) # Get the ouptut hidden states from the vision model - hidden_states, _ = self.vision_tower( - pixel_values, grid_thw - ) + hidden_states, _ = self.vision_tower(pixel_values, grid_thw) # Insert special image tokens in the input_ids inputs_embeds, 
_ = self.merge_input_ids_with_image_features( diff --git a/mlx_vlm/models/qwen3_5/vision.py b/mlx_vlm/models/qwen3_5/vision.py index 9b02abaf9..1e4a472c2 100644 --- a/mlx_vlm/models/qwen3_5/vision.py +++ b/mlx_vlm/models/qwen3_5/vision.py @@ -1,4 +1,5 @@ from ..qwen3_vl import VisionModel as Qwen3VLVisionModel + class VisionModel(Qwen3VLVisionModel): - pass \ No newline at end of file + pass diff --git a/mlx_vlm/models/qwen3_5_moe/config.py b/mlx_vlm/models/qwen3_5_moe/config.py index 368f22f07..1aab30bb6 100644 --- a/mlx_vlm/models/qwen3_5_moe/config.py +++ b/mlx_vlm/models/qwen3_5_moe/config.py @@ -11,10 +11,16 @@ class VisionConfig(Qwen3VLVisionConfig): model_type: str = "qwen3_5_moe" def __post_init__(self): - if self.deepstack_visual_indexes is not None and len(self.deepstack_visual_indexes) > 0: - raise ValueError(f"deepstack is disabled for qwen3.5 temporally, but it is set to {self.deepstack_visual_indexes}") + if ( + self.deepstack_visual_indexes is not None + and len(self.deepstack_visual_indexes) > 0 + ): + raise ValueError( + f"deepstack is disabled for qwen3.5 temporally, but it is set to {self.deepstack_visual_indexes}" + ) self.deepstack_visual_indexes = [] + @dataclass class TextConfig(BaseModelConfig): model_type: str @@ -38,17 +44,30 @@ class TextConfig(BaseModelConfig): attention_bias: bool = False head_dim: Optional[int] = None rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( - default_factory=lambda: {"type": "default", "mrope_section": [11, 11, 10], "rope_theta": 100000, "partial_rotary_factor": 0.25} + default_factory=lambda: { + "type": "default", + "mrope_section": [11, 11, 10], + "rope_theta": 100000, + "partial_rotary_factor": 0.25, + } ) full_attention_interval: int = 4 def __post_init__(self): if self.rope_parameters: # Normalize rope_parameters keys (accept both 'rope_type' and 'type') - if "type" not in self.rope_parameters and "rope_type" in self.rope_parameters: + if ( + "type" not in 
self.rope_parameters + and "rope_type" in self.rope_parameters + ): self.rope_parameters["type"] = self.rope_parameters.pop("rope_type") - required_keys = {"mrope_section", "type", "rope_theta", "partial_rotary_factor"} + required_keys = { + "mrope_section", + "type", + "rope_theta", + "partial_rotary_factor", + } if not all(key in self.rope_parameters for key in required_keys): raise ValueError(f"rope_parameters must contain keys {required_keys}") diff --git a/mlx_vlm/models/qwen3_5_moe/language.py b/mlx_vlm/models/qwen3_5_moe/language.py index 40ea49d5d..620b1b489 100644 --- a/mlx_vlm/models/qwen3_5_moe/language.py +++ b/mlx_vlm/models/qwen3_5_moe/language.py @@ -1,17 +1,14 @@ -from typing import Optional, Any +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn - from mlx_lm.models.switch_layers import SwitchGLU -from ..qwen3_5.language import ( - LanguageModel as Qwen3_5LanguageModel, - Qwen3_5Attention as Qwen3_5MoeAttention, - Qwen3_5GatedDeltaNet as Qwen3_5MoeGatedDeltaNet, - Qwen3_5MLP as Qwen3_5MoeMLP, - Qwen3_5Model, -) +from ..qwen3_5.language import LanguageModel as Qwen3_5LanguageModel +from ..qwen3_5.language import Qwen3_5Attention as Qwen3_5MoeAttention +from ..qwen3_5.language import Qwen3_5GatedDeltaNet as Qwen3_5MoeGatedDeltaNet +from ..qwen3_5.language import Qwen3_5MLP as Qwen3_5MoeMLP +from ..qwen3_5.language import Qwen3_5Model from .config import ModelConfig, TextConfig diff --git a/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py b/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py index 374f8f949..231aadb80 100644 --- a/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py +++ b/mlx_vlm/models/qwen3_5_moe/qwen3_5_moe.py @@ -2,11 +2,11 @@ import mlx.nn as nn from ..qwen3_5 import Model as Qwen3_5Model - from .config import ModelConfig from .language import LanguageModel from .vision import VisionModel + class Model(Qwen3_5Model): def __init__(self, config: ModelConfig): @@ -31,7 +31,9 @@ def sanitize(self, weights): 
weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_weight weights[f"{prefix}.switch_mlp.up_proj.weight"] = up_weights # down_proj - weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop(f"{prefix}.experts.down_proj") + weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop( + f"{prefix}.experts.down_proj" + ) norm_keys = ( ".input_layernorm.weight", diff --git a/mlx_vlm/models/qwen3_5_moe/vision.py b/mlx_vlm/models/qwen3_5_moe/vision.py index 9b02abaf9..1e4a472c2 100644 --- a/mlx_vlm/models/qwen3_5_moe/vision.py +++ b/mlx_vlm/models/qwen3_5_moe/vision.py @@ -1,4 +1,5 @@ from ..qwen3_vl import VisionModel as Qwen3VLVisionModel + class VisionModel(Qwen3VLVisionModel): - pass \ No newline at end of file + pass