From 63fef0e3420721f5a1e3d66dd8d1e67dae2e95cb Mon Sep 17 00:00:00 2001
From: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Date: Wed, 11 Feb 2026 01:11:15 +0800
Subject: [PATCH 1/5] support text-only qwen3.5 series

Co-authored-by: johnmai-dev
---
 mlx_lm/models/qwen3_5.py     | 376 +++++++++++++++++++++++++++++++++++
 mlx_lm/models/qwen3_5_moe.py |  55 +++++
 2 files changed, 431 insertions(+)
 create mode 100644 mlx_lm/models/qwen3_5.py
 create mode 100644 mlx_lm/models/qwen3_5_moe.py

diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py
new file mode 100644
index 000000000..984f1aa7e
--- /dev/null
+++ b/mlx_lm/models/qwen3_5.py
@@ -0,0 +1,376 @@
+# Copyright © 2026 Apple Inc.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+
+from .base import (
+    BaseModelArgs,
+    create_attention_mask,
+    create_ssm_mask,
+)
+from .cache import ArraysCache, KVCache
+from .gated_delta import gated_delta_update
+from .qwen3_next import Qwen3NextAttention as Attention
+from .qwen3_next import Qwen3NextMLP as MLP
+from .qwen3_next import Qwen3NextRMSNormGated as RMSNormGated
+from .qwen3_next import Qwen3NextSparseMoeBlock as SparseMoeBlock
+
+
+@dataclass
+class TextModelArgs(BaseModelArgs):
+    model_type: str = ""
+    hidden_size: int = 4096
+    intermediate_size: int = 14336
+    num_hidden_layers: int = 32
+    num_attention_heads: int = 32
+    rms_norm_eps: float = 1e-6
+    vocab_size: int = 151936
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 131072
+    linear_num_value_heads: int = 64
+    linear_num_key_heads: int = 16
+    linear_key_head_dim: int = 192
+    linear_value_head_dim: int = 128
+    linear_conv_kernel_dim: int = 4
+    tie_word_embeddings: bool = False
+    attention_bias: bool = False
+    head_dim: Optional[int] = None
+    full_attention_interval: int = 4
+
+    # MoE fields (optional, for Qwen3_5MoeForConditionalGeneration)
+    num_experts: int = 0
+    num_experts_per_tok: int = 0
+    decoder_sparse_step: int = 1
+    shared_expert_intermediate_size: int = 0
+    moe_intermediate_size: int = 0
+    norm_topk_prob: bool = True
+
+    # Rope parameters
+    rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field(
+        default_factory=lambda: {
+            "type": "default",
+            "mrope_section": [11, 11, 10],
+            "rope_theta": 100000,
+            "partial_rotary_factor": 0.25,
+        }
+    )
+
+    # Derived from rope_parameters (set in __post_init__)
+    partial_rotary_factor: float = 0.25
+    rope_theta: float = 100000.0
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+    def __post_init__(self):
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+
+        if self.rope_parameters:
+            if (
+                "type" not in self.rope_parameters
+                and "rope_type" in self.rope_parameters
+            ):
+                self.rope_parameters["type"] = self.rope_parameters.pop("rope_type")
+
+            self.partial_rotary_factor = self.rope_parameters.get(
+                "partial_rotary_factor", 0.25
+            )
+            self.rope_theta = self.rope_parameters.get("rope_theta", 100000.0)
+            self.rope_scaling = self.rope_parameters
+
+
+class GatedDeltaNet(nn.Module):
+    def __init__(self, config: TextModelArgs):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_v_heads = config.linear_num_value_heads
+        self.num_k_heads = config.linear_num_key_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+        if self.num_v_heads % self.num_k_heads != 0:
+            raise ValueError(
+                f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})"
+            )
+
+        self.conv_kernel_size = config.linear_conv_kernel_dim
+        self.layer_norm_epsilon = config.rms_norm_eps
+
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            bias=False,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            padding=0,
+        )
+
+        self.in_proj_qkv = nn.Linear(
+            self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
+        )
+        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
+        self.dt_bias = mx.ones(self.num_v_heads)
+
+        A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,))
+        self.A_log = mx.log(A)
+
+        self.norm = RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
+
+        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, S, _ = inputs.shape
+
+        qkv = self.in_proj_qkv(inputs)
+        z = self.in_proj_z(inputs).reshape(B, S, self.num_v_heads, self.head_v_dim)
+        b = self.in_proj_b(inputs)
+        a = self.in_proj_a(inputs)
+
+        if cache is not None and cache[0] is not None:
+            conv_state = cache[0]
+        else:
+            conv_state = mx.zeros(
+                (B, self.conv_kernel_size - 1, self.conv_dim),
+                dtype=inputs.dtype,
+            )
+
+        if mask is not None:
+            qkv = mx.where(mask[..., None], qkv, 0)
+        conv_input = mx.concatenate([conv_state, qkv], axis=1)
+        if cache is not None:
+            cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :]
+        conv_out = nn.silu(self.conv1d(conv_input))
+
+        q, k, v = [
+            t.reshape(B, S, h, d)
+            for t, h, d in zip(
+                mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1),
+                [self.num_k_heads, self.num_k_heads, self.num_v_heads],
+                [self.head_k_dim, self.head_k_dim, self.head_v_dim],
+            )
+        ]
+
+        state = cache[1] if cache else None
+        inv_scale = k.shape[-1] ** -0.5
+        q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
+        k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)
+
+        out, state = gated_delta_update(
+            q,
+            k,
+            v,
+            a,
+            b,
+            self.A_log,
+            self.dt_bias,
+            state,
+            mask,
+            use_kernel=not self.training,
+        )
+
+        if cache is not None:
+            cache[1] = state
+
+        out = self.norm(out, z)
+        return self.out_proj(out.reshape(B, S, -1))
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, args: TextModelArgs, layer_idx: int):
+        super().__init__()
+        self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0
+        if self.is_linear:
+            self.linear_attn = GatedDeltaNet(args)
+        else:
+            self.self_attn = Attention(args)
+
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+
+        if args.num_experts > 0:
+            self.mlp = SparseMoeBlock(args)
+        else:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        if self.is_linear:
+            r = self.linear_attn(self.input_layernorm(x), mask, cache)
+        else:
+            r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        out = h + self.mlp(self.post_attention_layernorm(h))
+        return out
+
+
+class Qwen3_5TextModel(nn.Module):
+    def __init__(self, args: TextModelArgs):
+        super().__init__()
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            DecoderLayer(args=args, layer_idx=i) for i in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.ssm_idx = 0
+        self.fa_idx = args.full_attention_interval - 1
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+        input_embeddings: Optional[mx.array] = None,
+    ) -> mx.array:
+        if input_embeddings is not None:
+            hidden_states = input_embeddings
+        else:
+            hidden_states = self.embed_tokens(inputs)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        fa_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
+        ssm_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])
+
+        for layer, c in zip(self.layers, cache):
+            mask = ssm_mask if layer.is_linear else fa_mask
+            hidden_states = layer(hidden_states, mask=mask, cache=c)
+
+        return self.norm(hidden_states)
+
+
+class TextModel(nn.Module):
+    def __init__(self, args: TextModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = Qwen3_5TextModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+        input_embeddings: Optional[mx.array] = None,
+    ) -> mx.array:
+        out = self.model(inputs, cache, input_embeddings=input_embeddings)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    def make_cache(self):
+        return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]
+
+    def sanitize(self, weights):
+        weights = {k: v for k, v in weights.items() if "mtp." not in k}
+
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
+
+        norm_keys = (
+            ".input_layernorm.weight",
+            ".post_attention_layernorm.weight",
+            "model.norm.weight",
+            ".q_norm.weight",
+            ".k_norm.weight",
+        )
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+            if any(k.endswith(sfx) for sfx in norm_keys):
+                if v.ndim == 1:
+                    weights[k] = v + 1.0
+        return weights
+
+    @property
+    def quant_predicate(self):
+        if self.args.num_experts <= 0:
+            return None
+
+        def predicate(path, _):
+            if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"):
+                return {"group_size": 64, "bits": 8}
+            return True
+
+        return predicate
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    text_config: dict
+
+    @classmethod
+    def from_dict(cls, params):
+        if "text_config" not in params:
+            return cls(model_type=params["model_type"], text_config=params)
+        return super().from_dict(params)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.language_model = TextModel(TextModelArgs.from_dict(args.text_config))
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+        input_embeddings: Optional[mx.array] = None,
+    ):
+        return self.language_model(
+            inputs, cache=cache, input_embeddings=input_embeddings
+        )
+
+    def sanitize(self, weights):
+        weights = tree_unflatten(list(weights.items()))
+        weights = dict(tree_flatten(weights))
+
+        sanitized = {}
+        for key, value in weights.items():
+            if key.startswith("model.visual"):
+                continue
+            if key.startswith("model.language_model"):
+                key = key.replace("model.language_model", "language_model.model")
+            else:
+                key = "language_model." + key
+            sanitized[key] = value
+        return self.language_model.sanitize(sanitized)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def make_cache(self):
+        return self.language_model.make_cache()
+
+    @property
+    def quant_predicate(self):
+        return self.language_model.quant_predicate
diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py
new file mode 100644
index 000000000..3913e96a6
--- /dev/null
+++ b/mlx_lm/models/qwen3_5_moe.py
@@ -0,0 +1,55 @@
+# Copyright © 2026 Apple Inc.
+
+from dataclasses import dataclass
+
+from mlx.utils import tree_flatten, tree_unflatten
+
+from .base import BaseModelArgs
+from .qwen3_5 import Model as Qwen3_5Model
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    text_config: dict
+
+    @classmethod
+    def from_dict(cls, params):
+        if "text_config" not in params:
+            return cls(model_type=params["model_type"], text_config=params)
+        return super().from_dict(params)
+
+
+class Model(Qwen3_5Model):
+
+    def sanitize(self, weights):
+        weights = tree_unflatten(list(weights.items()))
+        weights = dict(tree_flatten(weights))
+
+        new_weights = {}
+        for key, value in weights.items():
+            if key.startswith("model.visual"):
+                continue
+            if key.startswith("model.language_model"):
+                key = key.replace("model.language_model", "language_model.model")
+            else:
+                key = "language_model." + key
+            new_weights[key] = value
+
+        for l in range(self.language_model.args.num_hidden_layers):
+            prefix = f"language_model.model.layers.{l}.mlp"
+            gate_up_key = f"{prefix}.experts.gate_up_proj"
+            if gate_up_key in new_weights:
+                gate_up = new_weights.pop(gate_up_key)
+                mid = gate_up.shape[-2] // 2
+                new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[
+                    ..., :mid, :
+                ]
+                new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[
+                    ..., mid:, :
+                ]
+                new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop(
+                    f"{prefix}.experts.down_proj"
+                )
+
+        return self.language_model.sanitize(new_weights)

From ae8cf11e3bce697ac62089013289b093003c3fc9 Mon Sep 17 00:00:00 2001
From: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Date: Wed, 11 Feb 2026 01:21:22 +0800
Subject: [PATCH 2/5] add test

---
 tests/test_models.py | 41 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/tests/test_models.py b/tests/test_models.py
index 46f8ad1ed..072ae818a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -2121,6 +2121,47 @@ def test_all_models(self):
                 "partial_rotary_factor": 0.5,
                 "max_position_embeddings": 1000,
             },
+            {
+                "model_type": "qwen3_5",
+                "hidden_size": 128,
+                "num_hidden_layers": 4,
+                "intermediate_size": 128,
+                "num_attention_heads": 8,
+                "num_key_value_heads": 4,
+                "vocab_size": 1000,
+                "linear_num_value_heads": 4,
+                "linear_num_key_heads": 4,
+                "linear_key_head_dim": 32,
+                "linear_value_head_dim": 32,
+                "linear_conv_kernel_dim": 3,
+                "rms_norm_eps": 1e-5,
+                "head_dim": 64,
+                "rope_theta": 1000.0,
+                "partial_rotary_factor": 0.5,
+                "max_position_embeddings": 1000,
+            },
+            {
+                "model_type": "qwen3_5_moe",
+                "hidden_size": 128,
+                "num_hidden_layers": 4,
+                "num_attention_heads": 8,
+                "num_key_value_heads": 4,
+                "vocab_size": 1000,
+                "linear_num_value_heads": 4,
+                "linear_num_key_heads": 4,
+                "linear_key_head_dim": 32,
+                "linear_value_head_dim": 32,
+                "linear_conv_kernel_dim": 3,
+                "num_experts": 4,
+                "num_experts_per_tok": 2,
+                "shared_expert_intermediate_size": 128,
+                "moe_intermediate_size": 128,
+                "rms_norm_eps": 1e-5,
+                "head_dim": 64,
+                "rope_theta": 1000.0,
+                "partial_rotary_factor": 0.5,
+                "max_position_embeddings": 1000,
+            },
             {
                 "model_type": "kimi_linear",
                 "vocab_size": 1000,

From 3b54f0d500bfe91a40268ef48793a91a12c5812d Mon Sep 17 00:00:00 2001
From: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Date: Wed, 11 Feb 2026 12:01:06 +0800
Subject: [PATCH 3/5] fix sanitize and add test

---
 mlx_lm/models/qwen3_5.py     |  7 ++++-
 mlx_lm/models/qwen3_5_moe.py |  2 ++
 tests/test_models.py         | 53 ++++++++++++++++++++++++++++++++++++
 3 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py
index 984f1aa7e..ebe919640 100644
--- a/mlx_lm/models/qwen3_5.py
+++ b/mlx_lm/models/qwen3_5.py
@@ -287,6 +287,9 @@ def make_cache(self):
         return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]
 
     def sanitize(self, weights):
+        should_shift_norm_weights = any("mtp." in k for k in weights) or any(
+            "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items()
+        )
         weights = {k: v for k, v in weights.items() if "mtp." not in k}
 
         if self.args.tie_word_embeddings:
@@ -302,7 +305,7 @@ def sanitize(self, weights):
         for k, v in weights.items():
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
-            if any(k.endswith(sfx) for sfx in norm_keys):
+            if should_shift_norm_weights and any(k.endswith(sfx) for sfx in norm_keys):
                 if v.ndim == 1:
                     weights[k] = v + 1.0
         return weights
@@ -359,6 +362,8 @@ def sanitize(self, weights):
                 continue
             if key.startswith("model.language_model"):
                 key = key.replace("model.language_model", "language_model.model")
+            elif key.startswith("language_model."):
+                pass
             else:
                 key = "language_model." + key
             sanitized[key] = value
diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py
index 3913e96a6..c94b3fb26 100644
--- a/mlx_lm/models/qwen3_5_moe.py
+++ b/mlx_lm/models/qwen3_5_moe.py
@@ -32,6 +32,8 @@ def sanitize(self, weights):
                 continue
             if key.startswith("model.language_model"):
                 key = key.replace("model.language_model", "language_model.model")
+            elif key.startswith("language_model."):
+                pass
             else:
                 key = "language_model." + key
             new_weights[key] = value
diff --git a/tests/test_models.py b/tests/test_models.py
index 072ae818a..ebd98e4f1 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -530,6 +530,59 @@ def test_qwen3(self):
         self.model_test_runner(
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
+        
+    def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
+        text_config = {
+            "hidden_size": 8,
+            "intermediate_size": 16,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 1,
+            "num_key_value_heads": 1,
+            "rms_norm_eps": 1e-5,
+            "vocab_size": 32,
+            "linear_num_value_heads": 1,
+            "linear_num_key_heads": 1,
+            "linear_key_head_dim": 4,
+            "linear_value_head_dim": 4,
+            "linear_conv_kernel_dim": 1,
+            "full_attention_interval": 1,
+            "tie_word_embeddings": False,
+            "max_position_embeddings": 64,
+        }
+        hf_norm_key = "model.language_model.layers.0.input_layernorm.weight"
+        mlx_norm_key = "language_model.model.layers.0.input_layernorm.weight"
+
+        for model_type, hf_mtp_key in (
+            ("qwen3_5", "mtp.fc.weights"),
+            ("qwen3_5_moe", "mtp.fc.weight"),
+        ):
+            module = importlib.import_module(f"mlx_lm.models.{model_type}")
+            args = module.ModelArgs.from_dict(
+                {
+                    "model_type": model_type,
+                    "text_config": {"model_type": model_type, **text_config},
+                }
+            )
+            model = module.Model(args)
+
+            base = mx.arange(8, dtype=mx.float32)
+
+            # Simulate convert sanitize on HF-style keys.
+            converted = model.sanitize(
+                {
+                    hf_norm_key: base,
+                    hf_mtp_key: mx.zeros((1,), dtype=mx.float32),
+                }
+            )
+            self.assertIn(mlx_norm_key, converted)
+            self.assertTrue(mx.array_equal(converted[mlx_norm_key], base + 1.0))
+            self.assertFalse(any("mtp." in k for k in converted))
+
+            # Simulate load sanitize on already-converted keys.
+            loaded = model.sanitize(converted)
+            self.assertTrue(
+                mx.array_equal(loaded[mlx_norm_key], converted[mlx_norm_key])
+            )
 
     def test_qwen2_moe(self):
         from mlx_lm.models import qwen2_moe

From 3f20611aec15d261c67760abc67f6c043530d5f9 Mon Sep 17 00:00:00 2001
From: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Date: Wed, 11 Feb 2026 12:07:32 +0800
Subject: [PATCH 4/5] make it more readable

---
 mlx_lm/models/qwen3_5.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py
index ebe919640..d2925a010 100644
--- a/mlx_lm/models/qwen3_5.py
+++ b/mlx_lm/models/qwen3_5.py
@@ -287,9 +287,11 @@ def make_cache(self):
         return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]
 
     def sanitize(self, weights):
-        should_shift_norm_weights = any("mtp." in k for k in weights) or any(
+        has_mtp_weights = any("mtp." in k for k in weights)
+        has_unsanitized_conv1d = any(
             "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items()
         )
+        should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d
         weights = {k: v for k, v in weights.items() if "mtp." not in k}
 
         if self.args.tie_word_embeddings:

From 2dc22b664a5ea40efa7f5c11ca49b7d6f239077d Mon Sep 17 00:00:00 2001
From: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Date: Wed, 11 Feb 2026 23:55:54 +0800
Subject: [PATCH 5/5] fix lint

---
 tests/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index ebd98e4f1..e07569f70 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -530,7 +530,7 @@ def test_qwen3(self):
         self.model_test_runner(
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
-        
+
     def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
         text_config = {
             "hidden_size": 8,