diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index b452d5b71..1f581106e 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,7 +14,7 @@ OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's `Mamba v1` inclusionAI's `Bailing MoE e.g. Ling-family`, `Bailing MoE Linear e.g. Ling-Linear-family`, Klear team - Kuaishou Technology's `Klear`, AI21 Lab's `Jamba` IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, Nikity's `Lille130m`, -Alibaba Qwen's `Qwen3Next`, Tele-AI's `TeleChat3`, and Allenai's `OLMoE` and `Olmo 3`; +Alibaba Qwen's `Qwen3Next` and `Qwen3.5 MoE`, Tele-AI's `TeleChat3`, and Allenai's `OLMoE` and `Olmo 3`; Helped add support for the following model architectures: Alibaba Qwen's `Qwen3 & Qwen3MoE)`; Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon` optimizer; @@ -26,4 +26,4 @@ Added support for the following other features: MoonshotAI's `Kimi-Linear`, LiquidAI's `LFM2` and `LFM2 MoE`, Google DeepMind's `Gemma 3`, TII's `Falcon H1` and InterLM's `InternLM 2.5`. - Ivan Fioravanti: Added support for the following architectures: - ServiceNow-AI's `Apriel 1.5`, Tencent's `Hunyuan Dense V1` and `Hunyuan MoE V1`. \ No newline at end of file + ServiceNow-AI's `Apriel 1.5`, Tencent's `Hunyuan Dense V1` and `Hunyuan MoE V1`. diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py new file mode 100644 index 000000000..8da242d1a --- /dev/null +++ b/mlx_lm/models/qwen3_5_moe.py @@ -0,0 +1,62 @@ +# Copyright © 2025 Apple Inc. + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten, tree_unflatten + +from . import qwen3_5_moe_text +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + text_config: dict + + @classmethod + def from_dict(cls, params): + if "text_config" not in params: + return cls(model_type=params["model_type"], text_config=params) + return super().from_dict(params) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.language_model = qwen3_5_moe_text.Model( + qwen3_5_moe_text.ModelArgs.from_dict(args.text_config) + ) + + def __call__( + self, + inputs: mx.array, + cache=None, + input_embeddings: Optional[mx.array] = None, + ) -> mx.array: + return self.language_model( + inputs, cache=cache, input_embeddings=input_embeddings + ) + + def sanitize(self, weights): + weights = tree_unflatten(list(weights.items())) + weights.pop("vision_tower", None) + weights = dict(tree_flatten(weights)) + + sanitized = {} + for key, value in weights.items(): + if not key.startswith("language_model."): + key = "language_model." + key + sanitized[key] = value + return sanitized + + @property + def layers(self): + return self.language_model.model.layers + + def make_cache(self): + return self.language_model.make_cache() diff --git a/mlx_lm/models/qwen3_5_moe_text.py b/mlx_lm/models/qwen3_5_moe_text.py new file mode 100644 index 000000000..e4793ed3e --- /dev/null +++ b/mlx_lm/models/qwen3_5_moe_text.py @@ -0,0 +1,343 @@ +# Copyright © 2026 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import mlx.core as mx +import mlx.nn as nn +from qwen3_next import ( + Qwen3NextAttention, + Qwen3NextMLP, + Qwen3NextRMSNormGated, +) + +from .base import ( + BaseModelArgs, + create_attention_mask, + create_ssm_mask, +) +from .cache import ArraysCache, KVCache +from .gated_delta import gated_delta_update +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + linear_num_value_heads: int + linear_num_key_heads: int + linear_key_head_dim: int + linear_value_head_dim: int + linear_conv_kernel_dim: int + num_experts: int + num_experts_per_tok: int + decoder_sparse_step: int + shared_expert_intermediate_size: int + mlp_only_layers: List[int] + moe_intermediate_size: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int + rope_theta: float + partial_rotary_factor: float + max_position_embeddings: int + head_dim: Optional[int] = None + norm_topk_prob: bool = False + tie_word_embeddings: bool = False + attention_bias: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + layer_types: Optional[List[str]] = None + full_attention_interval: int = 4 + + def __post_init__(self): + if self.layer_types is None: + self.layer_types = [ + ( + "linear_attention" + if (i + 1) % self.full_attention_interval != 0 + else "full_attention" + ) + for i in range(self.num_hidden_layers) + ] + + +class Qwen3_5MoEGatedDeltaNet(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + if self.num_v_heads % self.num_k_heads != 0: + raise ValueError( + f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})" + ) + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_norm_epsilon = config.rms_norm_eps + + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=0, + ) + + self.in_proj_qkv = nn.Linear( + self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + self.dt_bias = mx.ones(self.num_v_heads) + + A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,)) + self.A_log = mx.log(A) + + self.norm = Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, S, _ = inputs.shape + + qkv = self.in_proj_qkv(inputs) + z = self.in_proj_z(inputs).reshape(B, S, self.num_v_heads, self.head_v_dim) + b = self.in_proj_b(inputs) + a = self.in_proj_a(inputs) + + if cache is not None and cache[0] is not None: + conv_state = cache[0] + else: + conv_state = mx.zeros( + (B, self.conv_kernel_size - 1, self.conv_dim), + dtype=inputs.dtype, + ) + + if mask is not None: + qkv = mx.where(mask[..., None], qkv, 0) + conv_input = mx.concatenate([conv_state, qkv], axis=1) + if cache is not None: + cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :] + conv_out = nn.silu(self.conv1d(conv_input)) + + q, k, v = mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1) + q = q.reshape(B, S, self.num_k_heads, self.head_k_dim) + k = k.reshape(B, S, self.num_k_heads, self.head_k_dim) + v = v.reshape(B, S, self.num_v_heads, self.head_v_dim) + + state = cache[1] if cache else None + inv_scale = k.shape[-1] ** -0.5 + q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) + k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) + + out, state = gated_delta_update( + q, + k, + v, + a, + b, + self.A_log, + self.dt_bias, + state, + mask, + use_kernel=not self.training, + ) + + if cache is not None: + cache[1] = state + + out = self.norm(out, z) + return self.out_proj(out.reshape(B, S, -1)) + + +class Qwen3NextSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.moe_intermediate_size + shared_expert_intermediate_size = args.shared_expert_intermediate_size + + self.norm_topk_prob = args.norm_topk_prob + self.num_experts = num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + + self.gate = nn.Linear(dim, num_experts, bias=False) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + self.shared_expert = Qwen3NextMLP(dim, shared_expert_intermediate_size) + self.shared_expert_gate = nn.Linear(dim, 1, bias=False) + + def __call__( + self, + x: mx.array, + ) -> mx.array: + gates = self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:] + scores = mx.take_along_axis(gates, inds, axis=-1) + if self.norm_topk_prob: + scores = scores / scores.sum(axis=-1, keepdims=True) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + + shared_y = self.shared_expert(x) + shared_y = mx.sigmoid(self.shared_expert_gate(x)) * shared_y + + return y + shared_y + + +class Qwen3MoENextDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0 + if self.is_linear: + self.linear_attn = Qwen3_5MoEGatedDeltaNet(args) + else: + self.self_attn = Qwen3NextAttention(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + if (layer_idx not in args.mlp_only_layers) and ( + args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0 + ): + self.mlp = Qwen3NextSparseMoeBlock(args) + else: + self.mlp = Qwen3NextMLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + if self.is_linear: + r = self.linear_attn(self.input_layernorm(x), mask, cache) + else: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + out = h + self.mlp(self.post_attention_layernorm(h)) + return out + + +class Qwen3NextModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + Qwen3MoENextDecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.ssm_idx = 0 + self.fa_idx = args.full_attention_interval - 1 + + def __call__( + self, + inputs: mx.array, + cache: Optional[Any] = None, + ) -> mx.array: + hidden_states = self.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(self.layers) + + fa_mask = create_attention_mask(hidden_states, cache[self.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx]) + + for layer, c in zip(self.layers, cache): + mask = ssm_mask if layer.is_linear else fa_mask + hidden_states = layer(hidden_states, mask=mask, cache=c) + + return self.norm(hidden_states) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Qwen3NextModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[Any] = None, + ) -> mx.array: + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.model.layers + + def make_cache(self): + return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers] + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + weights = {key: value for key, value in weights.items() if "mtp." not in key} + + if self.args.tie_word_embeddings: + weights.pop("lm_head.weight", None) + + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}.mlp" + for n in ["up_proj", "down_proj", "gate_proj"]: + to_join = [ + weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + + norm_keys = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + "model.norm.weight", + ".q_norm.weight", + ".k_norm.weight", + ) + for k, v in weights.items(): + if "conv1d.weight" in k and v.shape[-1] != 1: + weights[k] = v.moveaxis(2, 1) + if any(k.endswith(sfx) for sfx in norm_keys): + if v.ndim == 1: + weights[k] = v + 1.0 + return weights + + @property + def quant_predicate(self): + def predicate(path, _): + if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): + return {"group_size": 64, "bits": 8} + return True + + return predicate diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 47c9971bd..fa23b46b3 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -45,7 +45,7 @@ class ModelArgs(BaseModelArgs): rope_theta: float partial_rotary_factor: float max_position_embeddings: int - head_dim: int + head_dim: Optional[int] = None norm_topk_prob: bool = False tie_word_embeddings: bool = False attention_bias: bool = False @@ -73,7 +73,7 @@ def __init__(self, args: ModelArgs): super().__init__() self.num_key_value_heads = args.num_key_value_heads self.num_attention_heads = args.num_attention_heads - self.head_dim = args.head_dim + self.head_dim = args.head_dim or (args.hidden_size // args.num_attention_heads) self.scale = self.head_dim**-0.5 self.q_proj = nn.Linear(