diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index a2dd8cbd7..63772aaf0 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -13,11 +13,11 @@ MLX LM was developed with contributions from the following individuals:
   THUKEG's `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's
   `Bailing MoE e.g. Ling-family`, Klear team - Kuaishou Technology's `Klear`,
   IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's
-  `Apertus`, Nikity's `Lille130m`, and Allenai's `OLMoE`; Added support for the
-  following training algorithms: `Full Weight Fine-Tuning`, and the `Muon`
+  `Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE`;
+  Helped add support for the following model architectures: Alibaba Qwen's `Qwen3 & Qwen3MoE`;
+  Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon`
   optimizer; Added support for the following other features: `Multiple Optimizers
-  to choose for training`, and `reporting training metrics to WandB (Weights &
-  Biases)`.
+  to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`.
 - Prince Canuma: Helped add support for the following model architectures:
   HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`,
   Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama
diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py
index eea1f42c6..d94758a7a 100644
--- a/mlx_lm/models/qwen3_moe.py
+++ b/mlx_lm/models/qwen3_moe.py
@@ -127,7 +127,7 @@ def __call__(
         gates = mx.softmax(gates, axis=-1, precise=True)
 
         k = self.top_k
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
+        inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:]
         scores = mx.take_along_axis(gates, inds, axis=-1)
         if self.norm_topk_prob:
             scores /= mx.sum(scores, axis=-1, keepdims=True)
diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py
new file mode 100644
index 000000000..19ea7c2d2
--- /dev/null
+++ b/mlx_lm/models/qwen3_next.py
@@ -0,0 +1,496 @@
+# Copyright © 2025 Apple Inc.
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .cache import KVCache, MambaCache
+from .rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    linear_num_value_heads: int
+    linear_num_key_heads: int
+    linear_key_head_dim: int
+    linear_value_head_dim: int
+    linear_conv_kernel_dim: int
+    num_experts: int
+    num_experts_per_tok: int
+    decoder_sparse_step: int
+    shared_expert_intermediate_size: int
+    mlp_only_layers: List[int]
+    moe_intermediate_size: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: int
+    rope_theta: float
+    partial_rotary_factor: float
+    max_position_embeddings: int
+    norm_topk_prob: bool = False
+    tie_word_embeddings: bool = False
+    attention_bias: bool = False
+    head_dim: Optional[int] = None
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    full_attention_interval: int = 4
+
+
+@partial(mx.compile, shapeless=True)
+def compute_g(A_log, a, dt_bias):
+    # Per-head decay factor in (0, 1): exp(-exp(A_log) * softplus(a + dt_bias))
+    return mx.exp(-mx.exp(A_log.astype(mx.float32)) * nn.softplus(a + dt_bias)).astype(
+        A_log.dtype
+    )
+
+
+def recurrent_gated_delta_rule(
+    query: mx.array,
+    key: mx.array,
+    value: mx.array,
+    a: mx.array,
+    b: mx.array,
+    A_log: mx.array,
+    dt_bias: mx.array,
+    state: mx.array,
+    use_qk_l2norm_in_kernel: bool = False,
+) -> Tuple[mx.array, mx.array]:
+    B, S, Hk, Dk = key.shape
+    Hv, Dv = value.shape[2:]
+    inv_scale = Dk**-0.5
+    if use_qk_l2norm_in_kernel:
+        query = (inv_scale**2) * mx.fast.rms_norm(query, None, 1e-6)
+        key = inv_scale * mx.fast.rms_norm(key, None, 1e-6)
+    else:
+        query = inv_scale * query
+
+    input_type = query.dtype
+    if (repeat_factor := Hv // Hk) > 1:
+        query = mx.repeat(query, repeat_factor, 2)
+        key = mx.repeat(key, repeat_factor, 2)
+
+    beta = mx.sigmoid(b)
+    g = compute_g(A_log, a, dt_bias)
+    outs = []
+    for i in range(S):
+        # Decay the recurrent state, then apply the beta-weighted delta update.
+        state = state * g[:, i, :, None, None]
+        kv_mem = (state * key[:, i, :, :, None]).sum(axis=-2)
+        delta = (value[:, i] - kv_mem) * beta[:, i, :, None]
+        state += key[:, i, :, :, None] * delta[..., None, :]
+        out = (state * query[:, i, :, :, None]).sum(axis=-2)
+        outs.append(out)
+    out = mx.stack(outs, axis=1).astype(input_type)
+    return out, state
+
+
+class Qwen3NextRMSNormGated(nn.Module):
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = mx.ones(hidden_size)
+
+    def __call__(
+        self, hidden_states: mx.array, gate: mx.array | None = None
+    ) -> mx.array:
+        x = mx.fast.rms_norm(hidden_states, self.weight, self.eps)
+        if gate is not None:
+            x = x * nn.silu(gate)
+        return x
+
+
+class Qwen3NextAttention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_key_value_heads = args.num_key_value_heads
+        self.num_attention_heads = args.num_attention_heads
+        self.head_dim = args.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(
+            args.hidden_size,
+            self.num_attention_heads * self.head_dim * 2,
+            bias=args.attention_bias,
+        )
+        self.k_proj = nn.Linear(
+            args.hidden_size,
+            self.num_key_value_heads * self.head_dim,
+            bias=args.attention_bias,
+        )
+        self.v_proj = nn.Linear(
+            args.hidden_size,
+            self.num_key_value_heads * self.head_dim,
+            bias=args.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            self.num_attention_heads * self.head_dim,
+            args.hidden_size,
+            bias=args.attention_bias,
+        )
+
+        self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps)
+        self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps)
+
+        self.rope = initialize_rope(
+            int(self.head_dim * args.partial_rotary_factor),
+            base=args.rope_theta,
+            traditional=False,
+            scaling_config=args.rope_scaling,
+            max_position_embeddings=args.max_position_embeddings,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        q_proj_output = self.q_proj(x)
+        queries, gate = mx.split(
+            q_proj_output.reshape(B, L, self.num_attention_heads, -1), 2, axis=-1
+        )
+        gate = gate.reshape(B, L, -1)
+
+        keys, values = self.k_proj(x), self.v_proj(x)
+
+        queries = self.q_norm(queries).transpose(0, 2, 1, 3)
+        keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose(
+            0, 2, 1, 3
+        )
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.o_proj(output * mx.sigmoid(gate))
+
+
+class Qwen3NextMLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Qwen3NextGatedDeltaNet(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_v_heads = config.linear_num_value_heads
+        self.num_k_heads = config.linear_num_key_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+        if self.num_v_heads % self.num_k_heads != 0:
+            raise ValueError(
+                f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})"
+            )
+
+        self.conv_kernel_size = config.linear_conv_kernel_dim
+        self.layer_norm_epsilon = config.rms_norm_eps
+
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            bias=False,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            padding=0,
+        )
+
+        self.in_proj_qkvz = nn.Linear(
+            self.hidden_size, self.key_dim * 2 + self.value_dim * 2, bias=False
+        )
+        self.in_proj_ba = nn.Linear(self.hidden_size, self.num_v_heads * 2, bias=False)
+
+        self.dt_bias = mx.ones(self.num_v_heads)
+
+        A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,))
+        self.A_log = mx.log(A)
+
+        self.norm = Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
+
+        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+    def fix_query_key_value_ordering(
+        self, mixed_qkvz: mx.array, mixed_ba: mx.array
+    ) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array, mx.array]:
+        nk, dn, nv, dv = (
+            self.num_k_heads,
+            self.head_k_dim,
+            self.num_v_heads,
+            self.head_v_dim,
+        )
+        mixed_qkvz = mixed_qkvz.reshape(*mixed_qkvz.shape[:-1], nk, -1)
+        mixed_ba = mixed_ba.reshape(*mixed_ba.shape[:-1], nk, -1)
+        q, k, v, z = mx.split(mixed_qkvz, [dn, 2 * dn, 2 * dn + nv // nk * dv], axis=-1)
+        b, a = mx.split(mixed_ba, [nv // nk], axis=-1)
+        return (
+            q,
+            k,
+            v.reshape(*v.shape[:2], -1, dv),
+            z.reshape(*z.shape[:2], -1, dv),
+            b.reshape(*b.shape[:2], nv),
+            a.reshape(*a.shape[:2], nv),
+        )
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, S, _ = inputs.shape
+        q, k, v, z, b, a = self.fix_query_key_value_ordering(
+            self.in_proj_qkvz(inputs), self.in_proj_ba(inputs)
+        )
+
+        # cache[0] holds the rolling causal-conv input window, cache[1] the recurrent state.
+        if cache is not None and cache[0] is not None:
+            conv_state = cache[0]
+        else:
+            conv_state = mx.zeros(
+                (B, self.conv_kernel_size - 1, self.conv_dim),
+                dtype=inputs.dtype,
+            )
+
+        mixed_qkv = mx.concatenate(
+            [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], axis=-1
+        )
+        conv_input = mx.concatenate([conv_state, mixed_qkv], axis=1)
+        if cache is not None:
+            cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :]
+        conv_out = nn.silu(self.conv1d(conv_input))
+
+        q, k, v = [
+            t.reshape(B, S, h, d)
+            for t, h, d in zip(
+                mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1),
+                [self.num_k_heads, self.num_k_heads, self.num_v_heads],
+                [self.head_k_dim, self.head_k_dim, self.head_v_dim],
+            )
+        ]
+
+        if cache is not None and cache[1] is not None:
+            state = cache[1]
+        else:
+            state = mx.zeros(
+                (B, self.num_v_heads, self.head_k_dim, self.head_v_dim),
+                dtype=inputs.dtype,
+            )
+
+        out, new_state = recurrent_gated_delta_rule(
+            q,
+            k,
+            v,
+            a,
+            b,
+            self.A_log,
+            self.dt_bias,
+            state,
+            use_qk_l2norm_in_kernel=True,
+        )
+
+        if cache is not None:
+            cache[1] = new_state
+
+        out = self.norm(out, z)
+        return self.out_proj(out.reshape(B, S, -1))
+
+
+class Qwen3NextSparseMoeBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        intermediate_size = args.moe_intermediate_size
+        shared_expert_intermediate_size = args.shared_expert_intermediate_size
+
+        self.norm_topk_prob = args.norm_topk_prob
+        self.num_experts = num_experts = args.num_experts
+        self.top_k = args.num_experts_per_tok
+
+        self.gate = nn.Linear(dim, num_experts, bias=False)
+        self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
+
+        self.shared_expert = Qwen3NextMLP(dim, shared_expert_intermediate_size)
+        self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
+
+    def __call__(
+        self,
+        x: mx.array,
+    ) -> mx.array:
+        gates = self.gate(x)
+        gates = mx.softmax(gates, axis=-1, precise=True)
+
+        k = self.top_k
+        inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:]
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        if self.norm_topk_prob:
+            scores = scores / scores.sum(axis=-1, keepdims=True)
+
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+
+        shared_y = self.shared_expert(x)
+        shared_y = mx.sigmoid(self.shared_expert_gate(x)) * shared_y
+
+        return y + shared_y
+
+
+class Qwen3NextDecoderLayer(nn.Module):
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__()
+        self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0
+        if self.is_linear:
+            self.linear_attn = Qwen3NextGatedDeltaNet(args)
+        else:
+            self.self_attn = Qwen3NextAttention(args)
+
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size,
+            eps=args.rms_norm_eps,
+        )
+        if (layer_idx not in args.mlp_only_layers) and (
+            args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3NextSparseMoeBlock(args)
+        else:
+            self.mlp = Qwen3NextMLP(args.hidden_size, args.intermediate_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        if self.is_linear:
+            r = self.linear_attn(self.input_layernorm(x), mask, cache)
+        else:
+            r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        out = h + self.mlp(self.post_attention_layernorm(h))
+        return out
+
+
+class Qwen3NextModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            Qwen3NextDecoderLayer(args=args, layer_idx=i)
+            for i in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.fa_idx = args.full_attention_interval - 1
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        hidden_states = self.embed_tokens(inputs)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        mask = create_attention_mask(hidden_states, cache[self.fa_idx])
+
+        for layer, c in zip(self.layers, cache):
+            hidden_states = layer(hidden_states, mask=mask, cache=c)
+
+        return self.norm(hidden_states)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = Qwen3NextModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        out = self.model(inputs, cache)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    def make_cache(self):
+        return [MambaCache() if l.is_linear else KVCache() for l in self.layers]
+
+    def sanitize(self, weights):
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        weights = {key: value for key, value in weights.items() if "mtp." not in key}
+
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
+
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}.mlp"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                to_join = [
+                    weights.pop(f"{prefix}.experts.{e}.{n}.weight")
+                    for e in range(self.args.num_experts)
+                ]
+                weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)
+
+        norm_keys = (
+            ".input_layernorm.weight",
+            ".post_attention_layernorm.weight",
+            "model.norm.weight",
+            ".q_norm.weight",
+            ".k_norm.weight",
+        )
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+            if any(k.endswith(sfx) for sfx in norm_keys):
+                if v.ndim == 1:
+                    weights[k] = v + 1.0
+        return weights
+
+    @property
+    def quant_predicate(self):
+        def predicate(path, _):
+            if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"):
+                return {"group_size": 64, "bits": 8}
+            return True
+
+        return predicate
diff --git a/mlx_lm/tuner/utils.py b/mlx_lm/tuner/utils.py
index 4b3e5d8f7..82988998a 100644
--- a/mlx_lm/tuner/utils.py
+++ b/mlx_lm/tuner/utils.py
@@ -128,13 +128,14 @@ def to_lora(layer):
         "longcat_flash",
         "seed_oss",
         "apertus",
+        "qwen3_next",
         "Klear",
         "lille-130m",
     }:
         keys = {"self_attn.q_proj", "self_attn.v_proj"}
         if model.model_type in ["mixtral", "phimoe"]:
             keys.add("block_sparse_moe.gate")
-        if model.model_type == "qwen2_moe":
+        if model.model_type in ["qwen2_moe", "qwen3_next"]:
             keys.add("mlp.gate")
             keys.add("mlp.shared_expert_gate")
         if model.model_type in ["olmoe", "qwen3_moe", "dots1", "Klear"]:
@@ -147,6 +148,15 @@ def to_lora(layer):
             keys.add("feed_forward.gate_proj")
             keys.add("feed_forward.up_proj")
             keys.add("feed_forward.down_proj")
+        elif model.model_type == "qwen3_next":
+            keys.add("linear_attn.in_proj_qkvz")
+            keys.add("linear_attn.out_proj")
+            keys.add("linear_attn.in_proj_ba")
+            keys.add("linear_attn.dt_bias")
+            keys.add("self_attn.q_proj")
+            keys.add("self_attn.k_proj")
+            keys.add("self_attn.v_proj")
+            keys.add("self_attn.o_proj")
     elif model.model_type == "gpt_bigcode":
         keys = {"attn.c_attn"}
     elif model.model_type == "gpt2":
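
Reviewer note (not part of the patch): a minimal sketch of exercising the new `qwen3_next` model type through the public mlx_lm Python API. The checkpoint path below is a placeholder, not a repository this diff provides; any Qwen3-Next checkpoint converted to MLX format with `"model_type": "qwen3_next"` should work the same way.

from mlx_lm import load, generate

# Placeholder path: substitute an MLX-converted Qwen3-Next checkpoint.
model, tokenizer = load("path/to/qwen3-next-mlx")

# generate() builds the prompt cache via Model.make_cache(), so the hybrid cache
# (MambaCache for linear-attention layers, KVCache for full-attention layers)
# is constructed automatically.
text = generate(
    model,
    tokenizer,
    prompt="Briefly explain the gated delta rule.",
    max_tokens=128,
)
print(text)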