From 345a2fd767c66afaa71d907a30b9ab2ee20623d7 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 18:21:14 +0200 Subject: [PATCH 01/44] in. com. --- mlx_lm/models/qwen3_next.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 mlx_lm/models/qwen3_next.py diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py new file mode 100644 index 000000000..e69de29bb From 9d73b4388b5ec38ae095c99cf779dc98596197a5 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 18:32:47 +0200 Subject: [PATCH 02/44] adding attention + gated rms norm --- mlx_lm/models/qwen3_next.py | 104 ++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index e69de29bb..737dffb30 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -0,0 +1,104 @@ +# Copyright © 2025 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import KVCache, MambaCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_experts: int + num_experts_per_tok: int + decoder_sparse_step: int + mlp_only_layers: List[int] + moe_intermediate_size: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int + rope_theta: float + tie_word_embeddings: bool + max_position_embeddings: int + norm_topk_prob: bool + attention_bias: bool + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = mx.ones(hidden_size) + + def __call__(self, hidden_states: mx.array, gate: mx.array = None) -> mx.array: + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + return mx.fast.rms_norm(hidden_states, self.weight, self.eps) + + +class Qwen3NextAttention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_key_value_heads = args.num_key_value_heads + self.num_attention_heads = args.num_attention_heads + self.head_dim = args.hidden_size // self.num_attention_heads + self.scale = self.head_dim**-0.5 + + self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * self.head_dim * 2, bias=args.attention_bias) + self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, args.hidden_size, bias=False) + + self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + + self.rope = nn.RoPE( + self.head_dim, + traditional=False, + base=args.rope_theta, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + q_proj_output = self.q_proj(x) + queries, gate = mx.split(q_proj_output.reshape(B, L, self.num_attention_heads, -1, 2), 2, axis=-1) + queries = queries.squeeze(-1) + gate = gate.squeeze(-1).reshape(B, L, -1) + + keys, values = self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries).transpose(0, 2, 1, 3) + keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + output = output * mx.sigmoid(gate) + + return self.o_proj(output) \ No newline at end of file From ae3ba82f4f5ff934490ed4496c780826f95d09c3 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 18:36:58 +0200 Subject: [PATCH 03/44] adding Qwen3NextDecoderLayer --- mlx_lm/models/qwen3_next.py | 46 ++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 737dffb30..02166b1b3 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -30,6 +30,7 @@ class ModelArgs(BaseModelArgs): max_position_embeddings: int norm_topk_prob: bool attention_bias: bool + layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -101,4 +102,47 @@ def __call__( output = output * mx.sigmoid(gate) - return self.o_proj(output) \ No newline at end of file + return self.o_proj(output) + + +class Qwen3NextDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + + # token mixer + self.layer_type = args.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet(args, layer_idx) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention(args, layer_idx) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + if (layer_idx not in args.mlp_only_layers) and ( + args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0 + ): + self.mlp = Qwen3NextSparseMoeBlock(args) + else: + self.mlp = Qwen3NextMLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + if self.layer_type == "linear_attention": + r = self.linear_attn(self.input_layernorm(x), mask, cache) + elif self.layer_type == "full_attention": + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + if isinstance(r, tuple): + r, _ = r + out = h + r + return out \ No newline at end of file From eeb9e22e98619c68a92036a50f99fc48eaf0d891 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 18:38:10 +0200 Subject: [PATCH 04/44] adding Qwen3NextModel --- mlx_lm/models/qwen3_next.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 02166b1b3..dae9034c5 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -108,9 +108,6 @@ def __call__( class Qwen3NextDecoderLayer(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() - self.hidden_size = args.hidden_size - - # token mixer self.layer_type = args.layer_types[layer_idx] if self.layer_type == "linear_attention": self.linear_attn = Qwen3NextGatedDeltaNet(args, layer_idx) @@ -145,4 +142,34 @@ def __call__( if isinstance(r, tuple): r, _ = r out = h + r - return out \ No newline at end of file + return out + + +class Qwen3NextModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + Qwen3NextDecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ): + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) \ No newline at end of file From 9e6268871bd9304c0f1f920b55cae5a5ccea40e3 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 18:40:59 +0200 Subject: [PATCH 05/44] adding Model --- mlx_lm/models/qwen3_next.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index dae9034c5..d339b0902 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -165,6 +165,8 @@ def __call__( if mask is None: mask = create_attention_mask(h, cache) + + # TODO add linear mask if cache is None: cache = [None] * len(self.layers) @@ -172,4 +174,26 @@ def __call__( for layer, c in zip(self.layers, cache): h = layer(h, mask, c) - return self.norm(h) \ No newline at end of file + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Qwen3NextModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ): + out = self.model(inputs, mask, cache) + return self.lm_head(out) + + @property + def layers(self): + return self.model.layers From 8aa201770bc28ac097e9baef52fab751505efcc1 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 18:41:44 +0200 Subject: [PATCH 06/44] adding MLP --- mlx_lm/models/qwen3_next.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index d339b0902..4cc32e5cc 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -105,6 +105,17 @@ def __call__( return self.o_proj(output) +class Qwen3NextMLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + class Qwen3NextDecoderLayer(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() From 0dd5093bde75e19c9a367af55ae5991054c1da01 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 19:00:07 +0200 Subject: [PATCH 07/44] adding Qwen3NextGatedDeltaNet --- mlx_lm/models/qwen3_next.py | 178 +++++++++++++++++++++++++++++++++++- 1 file changed, 177 insertions(+), 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 4cc32e5cc..e45aa3965 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -17,6 +17,11 @@ class ModelArgs(BaseModelArgs): num_hidden_layers: int intermediate_size: int num_attention_heads: int + linear_num_value_heads: int + linear_num_key_heads: int + linear_key_head_dim: int + linear_value_head_dim: int + linear_conv_kernel_dim: int num_experts: int num_experts_per_tok: int decoder_sparse_step: int @@ -34,7 +39,65 @@ class ModelArgs(BaseModelArgs): rope_scaling: Optional[Dict[str, Union[float, str]]] = None -class MambaRMSNormGated(nn.Module): +def recurrent_gated_delta_rule( + query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query_norm = mx.sqrt(mx.sum(query**2, axis=-1, keepdims=True) + 1e-12) + query = query / query_norm + key_norm = mx.sqrt(mx.sum(key**2, axis=-1, keepdims=True) + 1e-12) + key = key / key_norm + + query, key, value, beta, g = [ + mx.transpose(x, (0, 2, 1, 3)).astype(mx.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = mx.zeros((batch_size, num_heads, sequence_length, v_head_dim), dtype=value.dtype) + last_recurrent_state = ( + mx.zeros((batch_size, sequence_length, k_head_dim, v_head_dim), dtype=value.dtype) + if initial_state is None + else initial_state.astype(value.dtype) + ) + + for i in range(num_heads): + q_t = query[:, i, :, :] + k_t = key[:, i, :, :] + v_t = value[:, i, :, :] + g_t = mx.exp(g[:, i, :]).reshape(batch_size, sequence_length, 1, 1) + beta_t = beta[:, i, :].reshape(batch_size, sequence_length, 1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = mx.sum(last_recurrent_state * k_t.reshape(batch_size, sequence_length, k_head_dim, 1), axis=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.reshape(batch_size, sequence_length, k_head_dim, 1) * delta.reshape(batch_size, sequence_length, 1, v_head_dim) + core_attn_out = core_attn_out.at[:, i, :, :].set( + mx.sum(last_recurrent_state * q_t.reshape(batch_size, sequence_length, k_head_dim, 1), axis=-2) + ) + + if not output_final_state: + last_recurrent_state = None + + core_attn_out = mx.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype) + return core_attn_out, last_recurrent_state + +def apply_mask_to_padding_states(hidden_states: mx.array, attention_mask: mx.array): + if ( + attention_mask is not None + and attention_mask.shape[0] > 1 + and attention_mask.shape[1] > 1 + ): + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).astype(dtype) + + return hidden_states + +class Qwen3NextRMSNormGated(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.eps = eps @@ -116,6 +179,119 @@ def __call__(self, x) -> mx.array: return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) +class Qwen3NextGatedDeltaNet(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_norm_epsilon = config.rms_norm_eps + + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False) + self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False) + + self.dt_bias = mx.ones(self.num_v_heads) + + A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,)) + self.A_log = mx.log(A) + + self.norm = Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) + + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ): + hidden_states = apply_mask_to_padding_states(inputs, mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + conv_state, recurrent_state = None, None + + # getting projected states from cache if it exists + if cache is not None: + # decoding + conv_state = cache.conv_states[self.layer_idx] + recurrent_state = cache.recurrent_states[self.layer_idx] + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = mx.concatenate((query, key, value), axis=-1) + mixed_qkv = mx.transpose(mixed_qkv, (0, 2, 1)) + + if cache is not None: + conv_state = mx.pad(mixed_qkv, [(0, 0), (0, 0), (self.conv_kernel_size - mixed_qkv.shape[-1], 0)]) + cache.conv_states[self.layer_idx] = conv_state + else: + mixed_qkv = mx.sigmoid(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mx.transpose(mixed_qkv, (0, 2, 1)) + query, key, value = mx.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + axis=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = mx.sigmoid(b) + # If the model is loaded in fp16, without the .astype(mx.float32) here, A might be -inf + g = -mx.exp(self.A_log.astype(mx.float32)) * mx.softplus(a.astype(mx.float32) + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = mx.repeat(query, self.num_v_heads // self.num_k_heads, axis=2) + key = mx.repeat(key, self.num_v_heads // self.num_k_heads, axis=2) + + core_attn_out, _ = recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=recurrent_state is not None, + use_qk_l2norm_in_kernel=True, + ) + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + return self.out_proj(core_attn_out) + + class Qwen3NextDecoderLayer(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() From 91f527f0dd616a74f6d8791f47d864e88b6f5824 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 21:37:01 +0200 Subject: [PATCH 08/44] updates --- mlx_lm/models/qwen3_next.py | 106 +++++++++++++++++++++++++++--------- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index e45aa3965..ebe622d38 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -8,6 +8,7 @@ from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import KVCache, MambaCache +from .switch_layers import SwitchGLU @dataclass @@ -25,6 +26,7 @@ class ModelArgs(BaseModelArgs): num_experts: int num_experts_per_tok: int decoder_sparse_step: int + shared_expert_intermediate_size: int mlp_only_layers: List[int] moe_intermediate_size: int rms_norm_eps: float @@ -38,7 +40,7 @@ class ModelArgs(BaseModelArgs): layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None - +@mx.compile def recurrent_gated_delta_rule( query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False ): @@ -86,6 +88,8 @@ def recurrent_gated_delta_rule( core_attn_out = mx.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype) return core_attn_out, last_recurrent_state + +@mx.compile def apply_mask_to_padding_states(hidden_states: mx.array, attention_mask: mx.array): if ( attention_mask is not None @@ -97,6 +101,7 @@ def apply_mask_to_padding_states(hidden_states: mx.array, attention_mask: mx.arr return hidden_states + class Qwen3NextRMSNormGated(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() @@ -200,7 +205,7 @@ def __init__(self, config: ModelArgs): bias=False, kernel_size=self.conv_kernel_size, groups=self.conv_dim, - padding=self.conv_kernel_size - 1, + padding=0, ) projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 @@ -225,31 +230,33 @@ def __call__( ): hidden_states = apply_mask_to_padding_states(inputs, mask) - # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - conv_state, recurrent_state = None, None - - # getting projected states from cache if it exists - if cache is not None: - # decoding - conv_state = cache.conv_states[self.layer_idx] - recurrent_state = cache.recurrent_states[self.layer_idx] + recurrent_state = cache[1] if cache is not None else None projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) - mixed_qkv = mx.concatenate((query, key, value), axis=-1) - mixed_qkv = mx.transpose(mixed_qkv, (0, 2, 1)) - + mixed_qkv = mx.concatenate((query, key, value), axis=-1) # (batch_size, seq_len, self.conv_dim) + if cache is not None: - conv_state = mx.pad(mixed_qkv, [(0, 0), (0, 0), (self.conv_kernel_size - mixed_qkv.shape[-1], 0)]) - cache.conv_states[self.layer_idx] = conv_state + if cache[0] is None: + conv_state = mx.zeros( + (batch_size, self.conv_kernel_size - 1, self.conv_dim), + dtype=hidden_states.dtype, + ) + else: + conv_state = cache[0] + padded_input = mx.concatenate([conv_state, mixed_qkv], axis=1) + cache[0] = padded_input[:, -(self.conv_kernel_size - 1):, :] else: - mixed_qkv = mx.sigmoid(self.conv1d(mixed_qkv)[:, :, :seq_len]) - - mixed_qkv = mx.transpose(mixed_qkv, (0, 2, 1)) + padded_input = mx.pad( + mixed_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)] + ) + + # Depthwise conv, keep length = seq_len + mixed_qkv = mx.sigmoid(self.conv1d(padded_input)) query, key, value = mx.split( mixed_qkv, [ @@ -264,25 +271,27 @@ def __call__( value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) beta = mx.sigmoid(b) - # If the model is loaded in fp16, without the .astype(mx.float32) here, A might be -inf - g = -mx.exp(self.A_log.astype(mx.float32)) * mx.softplus(a.astype(mx.float32) + self.dt_bias) + + g = -mx.exp(self.A_log.astype(mx.float32)) * nn.softplus(a.astype(mx.float32) + self.dt_bias) if self.num_v_heads // self.num_k_heads > 1: query = mx.repeat(query, self.num_v_heads // self.num_k_heads, axis=2) key = mx.repeat(key, self.num_v_heads // self.num_k_heads, axis=2) - core_attn_out, _ = recurrent_gated_delta_rule( + core_attn_out, new_recurrent_state = recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, initial_state=recurrent_state, - output_final_state=recurrent_state is not None, + output_final_state=True if cache is not None else False, use_qk_l2norm_in_kernel=True, ) + if cache is not None: + cache[1] = new_recurrent_state z_shape_og = z.shape - # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) @@ -292,14 +301,52 @@ def __call__( return self.out_proj(core_attn_out) +class Qwen3NextSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.moe_intermediate_size + shared_expert_intermediate_size = args.shared_expert_intermediate_size + + self.num_experts = num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + + self.gate = nn.Linear(dim, num_experts, bias=False) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + self.shared_expert = Qwen3NextMLP(dim, shared_expert_intermediate_size) + self.shared_expert_gate = nn.Linear(dim, 1, bias=False) + + def __call__( + self, + x: mx.array, + ): + gates = self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + + shared_expert_output = self.shared_expert(x) + shared_expert_output = ( + mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output + ) + + return y + shared_expert_output + + class Qwen3NextDecoderLayer(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() self.layer_type = args.layer_types[layer_idx] if self.layer_type == "linear_attention": - self.linear_attn = Qwen3NextGatedDeltaNet(args, layer_idx) + self.linear_attn = Qwen3NextGatedDeltaNet(args) elif self.layer_type == "full_attention": - self.self_attn = Qwen3NextAttention(args, layer_idx) + self.self_attn = Qwen3NextAttention(args) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = nn.RMSNorm( @@ -384,3 +431,12 @@ def __call__( @property def layers(self): return self.model.layers + + def make_cache(self): + caches = [] + for l in self.layers: + if l.layer_type == "linear_attention": + caches.append(MambaCache()) + elif l.layer_type == "full_attention": + caches.append(KVCache()) + return caches \ No newline at end of file From 416e0c7f2b4c747020606060404bf46b4a96ae5c Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 21:41:53 +0200 Subject: [PATCH 09/44] updates --- mlx_lm/models/qwen3_next.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index ebe622d38..ca2f2ce5b 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -389,26 +389,44 @@ def __init__(self, args: ModelArgs): ] self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.fa_idx = 0 + for b in args.layer_types: + if b == "linear_attention": + break + elif b == "full_attention": + self.fa_idx += 1 + def __call__( self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, ): - h = self.embed_tokens(inputs) + hidden_states = self.embed_tokens(inputs) if mask is None: - mask = create_attention_mask(h, cache) - - # TODO add linear mask + attn_mask = create_attention_mask( + hidden_states, cache[self.fa_idx : self.fa_idx + 1] + ) if cache is None: cache = [None] * len(self.layers) + + cache_counter = 0 + for layer in self.layers: + if layer.layer_type == "linear_attention" or layer.layer_type == "linear_attention": + c = cache[cache_counter] + cache_counter += 1 + else: + c = None - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) + if layer.layer_type == "full_attention": + mask_to_use = attn_mask + else: + mask_to_use = None + hidden_states = layer(hidden_states, mask=mask_to_use, cache=c) - return self.norm(h) + return self.norm(hidden_states) class Model(nn.Module): From 936a72a369c5862b5e8c2bfd898775d77512f982 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 21:45:35 +0200 Subject: [PATCH 10/44] upd. ackn. --- ACKNOWLEDGMENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 060c71bcd..2fc59b4a0 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -8,5 +8,5 @@ with a short description of your contribution(s) below. For example: MLX LM was developed with contributions from the following individuals: - Shunta Saito: Added support for PLaMo models. -- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's `Bailing MoE e.g. Ling-family`, IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, and Allenai's `OLMoE`; Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon` optimizer; Added support for the following other features: `Multiple Optimizers to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`. +- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's `Bailing MoE e.g. Ling-family`, IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE`; Helped add support for the following model architectures: Alibaba Qwen's `Qwen3 & Qwen3MoE)`; Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon` optimizer; Added support for the following other features: `Multiple Optimizers to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`. - Prince Canuma: Helped add support for the following model architectures: HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`, Google DeepMind's `Gemma 3`, and InterLM's `InternLM 2.5`. From a7200535d9a3e8b489bb94b1bd9e038579e7077c Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 21:48:10 +0200 Subject: [PATCH 11/44] nits --- mlx_lm/models/qwen3_next.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index ca2f2ce5b..d2727fdec 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -78,8 +78,9 @@ def recurrent_gated_delta_rule( kv_mem = mx.sum(last_recurrent_state * k_t.reshape(batch_size, sequence_length, k_head_dim, 1), axis=-2) delta = (v_t - kv_mem) * beta_t last_recurrent_state = last_recurrent_state + k_t.reshape(batch_size, sequence_length, k_head_dim, 1) * delta.reshape(batch_size, sequence_length, 1, v_head_dim) - core_attn_out = core_attn_out.at[:, i, :, :].set( - mx.sum(last_recurrent_state * q_t.reshape(batch_size, sequence_length, k_head_dim, 1), axis=-2) + core_attn_out[:, i, :, :] = mx.sum( + last_recurrent_state * q_t.reshape(batch_size, sequence_length, k_head_dim, 1), + axis=-2, ) if not output_final_state: From 089c2be7bb45e8f6af89f5959f08729dbcd6a652 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 21:51:28 +0200 Subject: [PATCH 12/44] making it trainable --- mlx_lm/tuner/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx_lm/tuner/utils.py b/mlx_lm/tuner/utils.py index fa21f5429..2af584688 100644 --- a/mlx_lm/tuner/utils.py +++ b/mlx_lm/tuner/utils.py @@ -128,6 +128,7 @@ def to_lora(layer): "longcat_flash", "seed_oss", "apertus", + "qwen3_next", }: keys = {"self_attn.q_proj", "self_attn.v_proj"} if model.model_type in ["mixtral", "phimoe"]: @@ -181,6 +182,10 @@ def to_lora(layer): keys = {"attn.attention.q_proj", "attn.attention.v_proj"} elif model.model_type == "bailing_moe": keys = {"attention.query_key_value", "attention.dense"} + elif model.model_type == "qwen3_next": + keys.add("self_attn.in_proj_qkvz") + keys.add("self_attn.in_proj_ba") + keys.add("self_attn.out_proj") elif model.model_type == "nemotron_h": keys.add("mixer.in_proj") keys.add("mixer.out_proj") From 222627ff6b00602940d7a3778c59d2567e134e59 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 23:06:15 +0200 Subject: [PATCH 13/44] inference fix --- mlx_lm/models/qwen3_next.py | 201 +++++++++++++++++++++++++----------- 1 file changed, 138 insertions(+), 63 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index d2727fdec..6cdb4707b 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -1,7 +1,7 @@ # Copyright © 2025 Apple Inc. from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Tuple import mlx.core as mx import mlx.nn as nn @@ -40,58 +40,88 @@ class ModelArgs(BaseModelArgs): layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None + @mx.compile def recurrent_gated_delta_rule( - query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False -): - initial_dtype = query.dtype + query: mx.array, + key: mx.array, + value: mx.array, + g: mx.array, + beta: mx.array, + initial_state: Optional[mx.array] = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False +) -> Tuple[mx.array, Optional[mx.array]]: if use_qk_l2norm_in_kernel: - query_norm = mx.sqrt(mx.sum(query**2, axis=-1, keepdims=True) + 1e-12) - query = query / query_norm - key_norm = mx.sqrt(mx.sum(key**2, axis=-1, keepdims=True) + 1e-12) - key = key / key_norm + query = query / mx.linalg.norm(query, axis=-1, keepdims=True) + key = key / mx.linalg.norm(key, axis=-1, keepdims=True) + + # Transpose and convert to float32 + query = mx.transpose(query, (0, 2, 1, 3)).astype(mx.float32) + key = mx.transpose(key, (0, 2, 1, 3)).astype(mx.float32) + value = mx.transpose(value, (0, 2, 1, 3)).astype(mx.float32) + g = mx.transpose(g, (0, 2, 1, 3)).astype(mx.float32) + beta = mx.transpose(beta, (0, 2, 1, 3)).astype(mx.float32) + + # Scale query + query = query / (key.shape[-1] ** 0.5) - query, key, value, beta, g = [ - mx.transpose(x, (0, 2, 1, 3)).astype(mx.float32) for x in (query, key, value, beta, g) - ] + B, H, T, Dk = key.shape + _, _, _, Dv = value.shape - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - scale = 1 / (query.shape[-1] ** 0.5) - query = query * scale + # Initialize state + if initial_state is None: + s = mx.zeros((B, H, Dk, Dv), dtype=mx.float32) + else: + s = initial_state.astype(mx.float32) + # Ensure s has the correct shape + if s.shape != (B, H, Dk, Dv): + s = s.reshape(B, H, Dk, Dv) - core_attn_out = mx.zeros((batch_size, num_heads, sequence_length, v_head_dim), dtype=value.dtype) - last_recurrent_state = ( - mx.zeros((batch_size, sequence_length, k_head_dim, v_head_dim), dtype=value.dtype) - if initial_state is None - else initial_state.astype(value.dtype) - ) + out = [] - for i in range(num_heads): - q_t = query[:, i, :, :] - k_t = key[:, i, :, :] - v_t = value[:, i, :, :] - g_t = mx.exp(g[:, i, :]).reshape(batch_size, sequence_length, 1, 1) - beta_t = beta[:, i, :].reshape(batch_size, sequence_length, 1) + for t in range(T): + k_t = key[:, :, t] # (B, H, Dk) + v_t = value[:, :, t] # (B, H, Dv) + q_t = query[:, :, t] # (B, H, Dk) + b_t = beta[:, :, t] # (B, H, 1) or (B, H,) + g_t = g[:, :, t] # (B, H, 1) or (B, H,) - last_recurrent_state = last_recurrent_state * g_t - kv_mem = mx.sum(last_recurrent_state * k_t.reshape(batch_size, sequence_length, k_head_dim, 1), axis=-2) - delta = (v_t - kv_mem) * beta_t - last_recurrent_state = last_recurrent_state + k_t.reshape(batch_size, sequence_length, k_head_dim, 1) * delta.reshape(batch_size, sequence_length, 1, v_head_dim) - core_attn_out[:, i, :, :] = mx.sum( - last_recurrent_state * q_t.reshape(batch_size, sequence_length, k_head_dim, 1), - axis=-2, - ) + # Ensure beta and g have correct shapes for broadcasting + if b_t.ndim == 2: # (B, H,) + b_t = b_t[:, :, None] # (B, H, 1) + if g_t.ndim == 2: # (B, H,) + g_t = g_t[:, :, None, None] # (B, H, 1, 1) + elif g_t.ndim == 3: # (B, H, 1) + g_t = g_t[:, :, :, None] # (B, H, 1, 1) + + # Decay (forget gate) - ensure g_t broadcasts correctly + s = s * mx.exp(g_t) + + # Read - this should work with s: (B, H, Dk, Dv) and k_t: (B, H, Dk) + mem = mx.einsum("bhkd,bhk->bhd", s, k_t) # (B, H, Dv) + + # Write + delta = (v_t - mem) * b_t # (B, H, Dv) + + # Update state: k_t[:,:,:,None] is (B, H, Dk, 1), delta[:,:,None,:] is (B, H, 1, Dv) + s = s + k_t[:, :, :, None] * delta[:, :, None, :] + + # Project to output + output_t = mx.einsum("bhkd,bhk->bhd", s, q_t) # (B, H, Dv) + out.append(output_t) - if not output_final_state: - last_recurrent_state = None + # Stack outputs and transpose back + output = mx.stack(out, axis=2) # (B, H, T, Dv) + output = mx.transpose(output, (0, 2, 1, 3)) # (B, T, H, Dv) - core_attn_out = mx.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype) - return core_attn_out, last_recurrent_state + final_state = s if output_final_state else None + + return output, final_state @mx.compile -def apply_mask_to_padding_states(hidden_states: mx.array, attention_mask: mx.array): +def apply_mask_to_padding_states(hidden_states: mx.array, attention_mask: mx.array) -> mx.array: if ( attention_mask is not None and attention_mask.shape[0] > 1 @@ -223,23 +253,48 @@ def __init__(self, config: ModelArgs): self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + nq, nk, nv, dv = self.num_k_heads, self.head_k_dim, self.num_v_heads, self.head_v_dim + mixed_qkvz = mixed_qkvz.reshape(mixed_qkvz.shape[:-1] + (nq, 2*nk + 2*nv*dv//nq)) + mixed_ba = mixed_ba.reshape(mixed_ba.shape[:-1] + (nq, 2*nv//nq)) + + # Split indices are cumulative positions + q, k, v, z = mx.split(mixed_qkvz, [nk, 2*nk, 2*nk + nv//nq*dv], axis=-1) + b, a = mx.split(mixed_ba, [nv//nq], axis=-1) + + v = v.reshape(v.shape[0], v.shape[1], -1, dv) + z = z.reshape(z.shape[0], z.shape[1], -1, dv) + b = b.reshape(b.shape[0], b.shape[1], nv) + a = a.reshape(a.shape[0], a.shape[1], nv) + return q, k, v, z, b, a + def __call__( self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, ): - hidden_states = apply_mask_to_padding_states(inputs, mask) - + if mask is not None: + hidden_states = apply_mask_to_padding_states(inputs, mask) + else: + hidden_states = inputs + batch_size, seq_len, _ = hidden_states.shape - recurrent_state = cache[1] if cache is not None else None - + + if cache is not None and cache[1] is not None: + recurrent_state = cache[1] + else: + recurrent_state = mx.zeros( + (batch_size, self.num_k_heads, self.head_k_dim, self.head_v_dim), + dtype=hidden_states.dtype, + ) + projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) - - mixed_qkv = mx.concatenate((query, key, value), axis=-1) # (batch_size, seq_len, self.conv_dim) + + mixed_qkv = mx.concatenate((query, key, value), axis=-1) if cache is not None: if cache[0] is None: @@ -256,28 +311,26 @@ def __call__( mixed_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)] ) - # Depthwise conv, keep length = seq_len mixed_qkv = mx.sigmoid(self.conv1d(padded_input)) + query, key, value = mx.split( mixed_qkv, - [ - self.key_dim, - self.key_dim, - self.value_dim, - ], + [self.key_dim, self.key_dim + self.key_dim], axis=-1, ) query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) - - beta = mx.sigmoid(b) - g = -mx.exp(self.A_log.astype(mx.float32)) * nn.softplus(a.astype(mx.float32) + self.dt_bias) + beta = mx.sigmoid(b).reshape(batch_size, seq_len, -1, 1) + g = ( + -mx.exp(self.A_log.astype(mx.float32)) + * nn.softplus(a.astype(mx.float32) + self.dt_bias) + ).reshape(batch_size, seq_len, -1, 1) if self.num_v_heads // self.num_k_heads > 1: query = mx.repeat(query, self.num_v_heads // self.num_k_heads, axis=2) key = mx.repeat(key, self.num_v_heads // self.num_k_heads, axis=2) - + core_attn_out, new_recurrent_state = recurrent_gated_delta_rule( query, key, @@ -288,9 +341,12 @@ def __call__( output_final_state=True if cache is not None else False, use_qk_l2norm_in_kernel=True, ) + # Updated storage of new_recurrent_state if cache is not None: cache[1] = new_recurrent_state - + else: + new_recurrent_state = None + z_shape_og = z.shape core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) @@ -298,7 +354,7 @@ def __call__( core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) - + return self.out_proj(core_attn_out) @@ -406,16 +462,18 @@ def __call__( hidden_states = self.embed_tokens(inputs) if mask is None: - attn_mask = create_attention_mask( - hidden_states, cache[self.fa_idx : self.fa_idx + 1] - ) + kv_cache = next(c for c in cache if isinstance(c, KVCache)) + attn_mask = create_attention_mask(hidden_states, [kv_cache]) if cache is None: cache = [None] * len(self.layers) cache_counter = 0 for layer in self.layers: - if layer.layer_type == "linear_attention" or layer.layer_type == "linear_attention": + if layer.layer_type == "linear_attention": + c = cache[cache_counter] + cache_counter += 1 + elif layer.layer_type == "full_attention": c = cache[cache_counter] cache_counter += 1 else: @@ -458,4 +516,21 @@ def make_cache(self): caches.append(MambaCache()) elif l.layer_type == "full_attention": caches.append(KVCache()) - return caches \ No newline at end of file + return caches + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + if f"{prefix}.mlp.experts.0.{n}.weight" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(to_join) + for k, v in weights.items(): + if "conv1d.weight" in k and v.shape[-1] != 1: + weights[k] = v.moveaxis(2, 1) + return weights \ No newline at end of file From 4176d9d24ebcdf9c54b99bccd28fcea57c50960a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 23:43:27 +0200 Subject: [PATCH 14/44] gibberish inference --- mlx_lm/models/qwen3_next.py | 99 +++++++++++++------------------------ 1 file changed, 33 insertions(+), 66 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 6cdb4707b..aafac808e 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -43,81 +43,48 @@ class ModelArgs(BaseModelArgs): @mx.compile def recurrent_gated_delta_rule( - query: mx.array, - key: mx.array, - value: mx.array, - g: mx.array, - beta: mx.array, - initial_state: Optional[mx.array] = None, - output_final_state: bool = False, + query: mx.array, key: mx.array, value: mx.array, g: mx.array, beta: mx.array, + initial_state: Optional[mx.array] = None, output_final_state: bool = False, use_qk_l2norm_in_kernel: bool = False ) -> Tuple[mx.array, Optional[mx.array]]: + + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: query = query / mx.linalg.norm(query, axis=-1, keepdims=True) key = key / mx.linalg.norm(key, axis=-1, keepdims=True) - # Transpose and convert to float32 - query = mx.transpose(query, (0, 2, 1, 3)).astype(mx.float32) - key = mx.transpose(key, (0, 2, 1, 3)).astype(mx.float32) - value = mx.transpose(value, (0, 2, 1, 3)).astype(mx.float32) - g = mx.transpose(g, (0, 2, 1, 3)).astype(mx.float32) - beta = mx.transpose(beta, (0, 2, 1, 3)).astype(mx.float32) - - # Scale query - query = query / (key.shape[-1] ** 0.5) + # Transpose to match PyTorch: (B, H, T, D) + query, key, value, beta, g = [mx.transpose(x, (0, 2, 1, 3)).astype(mx.float32) + for x in (query, key, value, beta, g)] B, H, T, Dk = key.shape - _, _, _, Dv = value.shape + Dv = value.shape[-1] + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale - # Initialize state if initial_state is None: - s = mx.zeros((B, H, Dk, Dv), dtype=mx.float32) + state = mx.zeros((B, H, Dk, Dv), dtype=mx.float32) else: - s = initial_state.astype(mx.float32) - # Ensure s has the correct shape - if s.shape != (B, H, Dk, Dv): - s = s.reshape(B, H, Dk, Dv) - - out = [] + state = initial_state.astype(mx.float32) + if len(state.shape) == 4 and state.shape[1] == T: + state = state[:, -1, :, :].reshape(B, H, Dk, Dv) + else: + state = state.reshape(B, H, Dk, Dv) + outputs = [] for t in range(T): - k_t = key[:, :, t] # (B, H, Dk) - v_t = value[:, :, t] # (B, H, Dv) - q_t = query[:, :, t] # (B, H, Dk) - b_t = beta[:, :, t] # (B, H, 1) or (B, H,) - g_t = g[:, :, t] # (B, H, 1) or (B, H,) - - # Ensure beta and g have correct shapes for broadcasting - if b_t.ndim == 2: # (B, H,) - b_t = b_t[:, :, None] # (B, H, 1) - if g_t.ndim == 2: # (B, H,) - g_t = g_t[:, :, None, None] # (B, H, 1, 1) - elif g_t.ndim == 3: # (B, H, 1) - g_t = g_t[:, :, :, None] # (B, H, 1, 1) - - # Decay (forget gate) - ensure g_t broadcasts correctly - s = s * mx.exp(g_t) - - # Read - this should work with s: (B, H, Dk, Dv) and k_t: (B, H, Dk) - mem = mx.einsum("bhkd,bhk->bhd", s, k_t) # (B, H, Dv) + g_t = mx.exp(g[:, :, t, :]) # exp of g, not -exp + g_t = mx.expand_dims(g_t, -1) # (B, H, Dv, 1) - # Write - delta = (v_t - mem) * b_t # (B, H, Dv) - - # Update state: k_t[:,:,:,None] is (B, H, Dk, 1), delta[:,:,None,:] is (B, H, 1, Dv) - s = s + k_t[:, :, :, None] * delta[:, :, None, :] - - # Project to output - output_t = mx.einsum("bhkd,bhk->bhd", s, q_t) # (B, H, Dv) - out.append(output_t) - - # Stack outputs and transpose back - output = mx.stack(out, axis=2) # (B, H, T, Dv) - output = mx.transpose(output, (0, 2, 1, 3)) # (B, T, H, Dv) - - final_state = s if output_final_state else None + state = state * g_t + mem = mx.einsum("bhkv,bhk->bhv", state, key[:, :, t]) + delta = (value[:, :, t] - mem) * beta[:, :, t] + state = state + mx.einsum("bhk,bhv->bhkv", key[:, :, t], delta) + outputs.append(mx.einsum("bhkv,bhk->bhv", state, query[:, :, t])) - return output, final_state + out = mx.transpose(mx.stack(outputs, axis=2), (0, 2, 1, 3)) + return out.astype(initial_dtype), state if output_final_state else None @mx.compile @@ -230,6 +197,7 @@ def __init__(self, config: ModelArgs): self.layer_norm_epsilon = config.rms_norm_eps self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( in_channels=self.conv_dim, out_channels=self.conv_dim, @@ -285,7 +253,7 @@ def __call__( recurrent_state = cache[1] else: recurrent_state = mx.zeros( - (batch_size, self.num_k_heads, self.head_k_dim, self.head_v_dim), + (batch_size, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=hidden_states.dtype, ) @@ -306,12 +274,11 @@ def __call__( conv_state = cache[0] padded_input = mx.concatenate([conv_state, mixed_qkv], axis=1) cache[0] = padded_input[:, -(self.conv_kernel_size - 1):, :] + mixed_qkv = nn.silu(self.conv1d(padded_input)) else: - padded_input = mx.pad( - mixed_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)] - ) - - mixed_qkv = mx.sigmoid(self.conv1d(padded_input)) + padded_input = mx.pad(mixed_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) + conv_output = self.conv1d(padded_input) + mixed_qkv = nn.silu(conv_output[:, :seq_len, :]) query, key, value = mx.split( mixed_qkv, From 0c5507cf1dc23062536305b58c17939c68b00377 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 9 Sep 2025 23:46:09 +0200 Subject: [PATCH 15/44] fix training --- mlx_lm/models/qwen3_next.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index aafac808e..a5011a414 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -428,13 +428,15 @@ def __call__( ): hidden_states = self.embed_tokens(inputs) - if mask is None: - kv_cache = next(c for c in cache if isinstance(c, KVCache)) - attn_mask = create_attention_mask(hidden_states, [kv_cache]) - if cache is None: cache = [None] * len(self.layers) - + + attn_mask = None + if mask is None: + kv_caches = [c for c in cache if isinstance(c, KVCache)] + if kv_caches: + attn_mask = create_attention_mask(hidden_states, kv_caches) + cache_counter = 0 for layer in self.layers: if layer.layer_type == "linear_attention": @@ -448,6 +450,8 @@ def __call__( if layer.layer_type == "full_attention": mask_to_use = attn_mask + elif layer.layer_type == "linear_attention": + mask_to_use = mask else: mask_to_use = None hidden_states = layer(hidden_states, mask=mask_to_use, cache=c) From aa65e7d5d03b2f5afa20745659a16d8240106189 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:55:17 +0200 Subject: [PATCH 16/44] fix for batching --- mlx_lm/models/qwen3_next.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index a5011a414..d347be466 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -423,7 +423,6 @@ def __init__(self, args: ModelArgs): def __call__( self, inputs: mx.array, - mask: Optional[mx.array] = None, cache: Optional[Any] = None, ): hidden_states = self.embed_tokens(inputs) @@ -432,10 +431,9 @@ def __call__( cache = [None] * len(self.layers) attn_mask = None - if mask is None: - kv_caches = [c for c in cache if isinstance(c, KVCache)] - if kv_caches: - attn_mask = create_attention_mask(hidden_states, kv_caches) + kv_caches = [c for c in cache if isinstance(c, KVCache)] + if kv_caches: + attn_mask = create_attention_mask(hidden_states, kv_caches) cache_counter = 0 for layer in self.layers: @@ -450,8 +448,6 @@ def __call__( if layer.layer_type == "full_attention": mask_to_use = attn_mask - elif layer.layer_type == "linear_attention": - mask_to_use = mask else: mask_to_use = None hidden_states = layer(hidden_states, mask=mask_to_use, cache=c) @@ -470,10 +466,9 @@ def __init__(self, args: ModelArgs): def __call__( self, inputs: mx.array, - mask: Optional[mx.array] = None, cache: Optional[Any] = None, ): - out = self.model(inputs, mask, cache) + out = self.model(inputs, cache) return self.lm_head(out) @property From fd6c1109fc8abc1c5041e872fd7d4fe759538e70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 11:49:26 +0200 Subject: [PATCH 17/44] nits --- mlx_lm/models/qwen3_next.py | 38 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index d347be466..c6343b407 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -249,14 +249,6 @@ def __call__( batch_size, seq_len, _ = hidden_states.shape - if cache is not None and cache[1] is not None: - recurrent_state = cache[1] - else: - recurrent_state = mx.zeros( - (batch_size, self.num_v_heads, self.head_k_dim, self.head_v_dim), - dtype=hidden_states.dtype, - ) - projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) @@ -264,21 +256,28 @@ def __call__( mixed_qkv = mx.concatenate((query, key, value), axis=-1) + # Explicit cache separation and handling if cache is not None: - if cache[0] is None: + conv_state, recurrent_state = cache + + # Conv-State handling + if conv_state is None: conv_state = mx.zeros( (batch_size, self.conv_kernel_size - 1, self.conv_dim), dtype=hidden_states.dtype, ) - else: - conv_state = cache[0] padded_input = mx.concatenate([conv_state, mixed_qkv], axis=1) - cache[0] = padded_input[:, -(self.conv_kernel_size - 1):, :] - mixed_qkv = nn.silu(self.conv1d(padded_input)) + conv_state = padded_input[:, -(self.conv_kernel_size - 1):, :] + conv_out = self.conv1d(padded_input) + mixed_qkv = nn.silu(conv_out[:, :seq_len, :]) + + # Update conv_state in cache + cache[0] = conv_state else: padded_input = mx.pad(mixed_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) - conv_output = self.conv1d(padded_input) - mixed_qkv = nn.silu(conv_output[:, :seq_len, :]) + conv_out = self.conv1d(padded_input) + mixed_qkv = nn.silu(conv_out[:, :seq_len, :]) + recurrent_state = None query, key, value = mx.split( mixed_qkv, @@ -298,6 +297,12 @@ def __call__( query = mx.repeat(query, self.num_v_heads // self.num_k_heads, axis=2) key = mx.repeat(key, self.num_v_heads // self.num_k_heads, axis=2) + # Recurrent state explicit handling + if recurrent_state is None: + recurrent_state = mx.zeros( + (batch_size, self.num_v_heads, self.head_k_dim, self.head_v_dim), + dtype=hidden_states.dtype, + ) core_attn_out, new_recurrent_state = recurrent_gated_delta_rule( query, key, @@ -308,6 +313,7 @@ def __call__( output_final_state=True if cache is not None else False, use_qk_l2norm_in_kernel=True, ) + # Updated storage of new_recurrent_state if cache is not None: cache[1] = new_recurrent_state @@ -397,8 +403,6 @@ def __call__( r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) - if isinstance(r, tuple): - r, _ = r out = h + r return out From fa20e46c8fae92263bba882462a43eeb07987d1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:04:05 +0200 Subject: [PATCH 18/44] optimize --- mlx_lm/models/qwen3_next.py | 218 ++++++++++-------------------------- 1 file changed, 59 insertions(+), 159 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index c6343b407..b0a548b6c 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -47,57 +47,22 @@ def recurrent_gated_delta_rule( initial_state: Optional[mx.array] = None, output_final_state: bool = False, use_qk_l2norm_in_kernel: bool = False ) -> Tuple[mx.array, Optional[mx.array]]: - - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = query / mx.linalg.norm(query, axis=-1, keepdims=True) - key = key / mx.linalg.norm(key, axis=-1, keepdims=True) - - # Transpose to match PyTorch: (B, H, T, D) - query, key, value, beta, g = [mx.transpose(x, (0, 2, 1, 3)).astype(mx.float32) - for x in (query, key, value, beta, g)] - - B, H, T, Dk = key.shape - Dv = value.shape[-1] - scale = 1.0 / (query.shape[-1] ** 0.5) - query = query * scale - - if initial_state is None: - state = mx.zeros((B, H, Dk, Dv), dtype=mx.float32) - else: - state = initial_state.astype(mx.float32) - if len(state.shape) == 4 and state.shape[1] == T: - state = state[:, -1, :, :].reshape(B, H, Dk, Dv) - else: - state = state.reshape(B, H, Dk, Dv) - - outputs = [] + query /= mx.linalg.norm(query, axis=-1, keepdims=True) + key /= mx.linalg.norm(key, axis=-1, keepdims=True) + query, key, value, beta, g = [mx.transpose(x, (0,2,1,3)).astype(mx.float32) for x in (query,key,value,beta,g)] + B,H,T,Dk = key.shape; Dv = value.shape[-1]; scale = 1.0 / (query.shape[-1]**0.5) + query *= scale + state = mx.zeros((B,H,Dk,Dv), dtype=mx.float32) if initial_state is None else initial_state.astype(mx.float32).reshape(B,H,Dk,Dv) + outs = [] for t in range(T): - g_t = mx.exp(g[:, :, t, :]) # exp of g, not -exp - g_t = mx.expand_dims(g_t, -1) # (B, H, Dv, 1) - - state = state * g_t - mem = mx.einsum("bhkv,bhk->bhv", state, key[:, :, t]) - delta = (value[:, :, t] - mem) * beta[:, :, t] - state = state + mx.einsum("bhk,bhv->bhkv", key[:, :, t], delta) - outputs.append(mx.einsum("bhkv,bhk->bhv", state, query[:, :, t])) - - out = mx.transpose(mx.stack(outputs, axis=2), (0, 2, 1, 3)) - return out.astype(initial_dtype), state if output_final_state else None - - -@mx.compile -def apply_mask_to_padding_states(hidden_states: mx.array, attention_mask: mx.array) -> mx.array: - if ( - attention_mask is not None - and attention_mask.shape[0] > 1 - and attention_mask.shape[1] > 1 - ): - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).astype(dtype) - - return hidden_states + state = state * mx.expand_dims(mx.exp(g[:, :, t, :]), -1) + mem = mx.einsum("bhkv,bhk->bhv", state, key[:,:,t]) + delta = (value[:,:,t] - mem) * beta[:,:,t] + state += mx.einsum("bhk,bhv->bhkv", key[:,:,t], delta) + outs.append(mx.einsum("bhkv,bhk->bhv", state, query[:,:,t])) + out = mx.transpose(mx.stack(outs, axis=2), (0,2,1,3)).astype(query.dtype) + return out, (state if output_final_state else None) class Qwen3NextRMSNormGated(nn.Module): @@ -223,112 +188,56 @@ def __init__(self, config: ModelArgs): def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): nq, nk, nv, dv = self.num_k_heads, self.head_k_dim, self.num_v_heads, self.head_v_dim - mixed_qkvz = mixed_qkvz.reshape(mixed_qkvz.shape[:-1] + (nq, 2*nk + 2*nv*dv//nq)) - mixed_ba = mixed_ba.reshape(mixed_ba.shape[:-1] + (nq, 2*nv//nq)) - - # Split indices are cumulative positions - q, k, v, z = mx.split(mixed_qkvz, [nk, 2*nk, 2*nk + nv//nq*dv], axis=-1) - b, a = mx.split(mixed_ba, [nv//nq], axis=-1) - - v = v.reshape(v.shape[0], v.shape[1], -1, dv) - z = z.reshape(z.shape[0], z.shape[1], -1, dv) - b = b.reshape(b.shape[0], b.shape[1], nv) - a = a.reshape(a.shape[0], a.shape[1], nv) - return q, k, v, z, b, a + mixed_qkvz = mixed_qkvz.reshape(*mixed_qkvz.shape[:-1], nq, 2*nk + 2*nv*dv//nq) + mixed_ba = mixed_ba.reshape(*mixed_ba.shape[:-1], nq, 2*nv//nq) + q,k,v,z = mx.split(mixed_qkvz,[nk,2*nk,2*nk+nv//nq*dv],axis=-1) + b,a = mx.split(mixed_ba,[nv//nq],axis=-1) + return ( + q, + k, + v.reshape(v.shape[0], v.shape[1], -1, dv), + z.reshape(z.shape[0], z.shape[1], -1, dv), + b.reshape(b.shape[0], b.shape[1], nv), + a.reshape(a.shape[0], a.shape[1], nv), + ) - def __call__( - self, - inputs: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ): - if mask is not None: - hidden_states = apply_mask_to_padding_states(inputs, mask) - else: - hidden_states = inputs - - batch_size, seq_len, _ = hidden_states.shape - - projected_states_qkvz = self.in_proj_qkvz(hidden_states) - projected_states_ba = self.in_proj_ba(hidden_states) - query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) - query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) - - mixed_qkv = mx.concatenate((query, key, value), axis=-1) + + def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): + B,L,_ = inputs.shape + qkvz, ba = self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) + q,k,v,z,b,a = self.fix_query_key_value_ordering(qkvz, ba) + q,k,v = (x.reshape(B,L,-1) for x in (q,k,v)) + mixed_qkv = mx.concatenate((q,k,v),-1) - # Explicit cache separation and handling if cache is not None: - conv_state, recurrent_state = cache - - # Conv-State handling + conv_state, rec_state = cache if conv_state is None: - conv_state = mx.zeros( - (batch_size, self.conv_kernel_size - 1, self.conv_dim), - dtype=hidden_states.dtype, - ) - padded_input = mx.concatenate([conv_state, mixed_qkv], axis=1) - conv_state = padded_input[:, -(self.conv_kernel_size - 1):, :] - conv_out = self.conv1d(padded_input) - mixed_qkv = nn.silu(conv_out[:, :seq_len, :]) - - # Update conv_state in cache - cache[0] = conv_state + conv_state = mx.zeros((B,self.conv_kernel_size-1,self.conv_dim),dtype=inputs.dtype) + padded = mx.concatenate([conv_state,mixed_qkv],1) + cache[0] = padded[:,-(self.conv_kernel_size-1):] + mixed_qkv = nn.silu(self.conv1d(padded)[:,:L]) else: - padded_input = mx.pad(mixed_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) - conv_out = self.conv1d(padded_input) - mixed_qkv = nn.silu(conv_out[:, :seq_len, :]) - recurrent_state = None - - query, key, value = mx.split( - mixed_qkv, - [self.key_dim, self.key_dim + self.key_dim], - axis=-1, - ) - query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) - key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) - value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) - - beta = mx.sigmoid(b).reshape(batch_size, seq_len, -1, 1) - g = ( - -mx.exp(self.A_log.astype(mx.float32)) - * nn.softplus(a.astype(mx.float32) + self.dt_bias) - ).reshape(batch_size, seq_len, -1, 1) - if self.num_v_heads // self.num_k_heads > 1: - query = mx.repeat(query, self.num_v_heads // self.num_k_heads, axis=2) - key = mx.repeat(key, self.num_v_heads // self.num_k_heads, axis=2) - - # Recurrent state explicit handling - if recurrent_state is None: - recurrent_state = mx.zeros( - (batch_size, self.num_v_heads, self.head_k_dim, self.head_v_dim), - dtype=hidden_states.dtype, - ) - core_attn_out, new_recurrent_state = recurrent_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - initial_state=recurrent_state, - output_final_state=True if cache is not None else False, - use_qk_l2norm_in_kernel=True, - ) + padded = mx.pad(mixed_qkv,[(0,0),(self.conv_kernel_size-1,0),(0,0)]) + mixed_qkv = nn.silu(self.conv1d(padded)[:,:L]); rec_state=None - # Updated storage of new_recurrent_state - if cache is not None: - cache[1] = new_recurrent_state - else: - new_recurrent_state = None + q,k,v = mx.split(mixed_qkv,[self.key_dim,2*self.key_dim],-1) + q = q.reshape(B,L,-1,self.head_k_dim); k = k.reshape(B,L,-1,self.head_k_dim); v = v.reshape(B,L,-1,self.head_v_dim) - z_shape_og = z.shape + beta = mx.sigmoid(b).reshape(B,L,-1,1) + g = (-mx.exp(self.A_log.astype(mx.float32))*nn.softplus(a.astype(mx.float32)+self.dt_bias)).reshape(B,L,-1,1) + if self.num_v_heads//self.num_k_heads>1: + q = mx.repeat(q,self.num_v_heads//self.num_k_heads,axis=2) + k = mx.repeat(k,self.num_v_heads//self.num_k_heads,axis=2) - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + if rec_state is None: + rec_state = mx.zeros((B,self.num_v_heads,self.head_k_dim,self.head_v_dim),dtype=inputs.dtype) + out,new_state = recurrent_gated_delta_rule(q,k,v,g,beta,rec_state,cache is not None,True) - return self.out_proj(core_attn_out) + if cache is not None: cache[1]=new_state + else: new_state=None + + out = self.norm(out.reshape(-1,out.shape[-1]),z.reshape(-1,z.shape[-1])).reshape(z.shape[0],z.shape[1],-1) + return self.out_proj(out) class Qwen3NextSparseMoeBlock(nn.Module): @@ -417,13 +326,6 @@ def __init__(self, args: ModelArgs): ] self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.fa_idx = 0 - for b in args.layer_types: - if b == "linear_attention": - break - elif b == "full_attention": - self.fa_idx += 1 - def __call__( self, inputs: mx.array, @@ -434,11 +336,6 @@ def __call__( if cache is None: cache = [None] * len(self.layers) - attn_mask = None - kv_caches = [c for c in cache if isinstance(c, KVCache)] - if kv_caches: - attn_mask = create_attention_mask(hidden_states, kv_caches) - cache_counter = 0 for layer in self.layers: if layer.layer_type == "linear_attention": @@ -450,8 +347,11 @@ def __call__( else: c = None + # Compute attention mask per layer as needed if layer.layer_type == "full_attention": - mask_to_use = attn_mask + mask_to_use = create_attention_mask(hidden_states, c) + elif layer.layer_type == "linear_attention": + mask_to_use = None else: mask_to_use = None hidden_states = layer(hidden_states, mask=mask_to_use, cache=c) From f95f3fe2ce25624c1f3655f39bd7ee38fed6d959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:54:07 +0200 Subject: [PATCH 19/44] updates --- mlx_lm/models/qwen3_next.py | 93 ++++++++++++++++++++++++++++++------- 1 file changed, 75 insertions(+), 18 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index b0a548b6c..3d98417fa 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -43,26 +43,83 @@ class ModelArgs(BaseModelArgs): @mx.compile def recurrent_gated_delta_rule( - query: mx.array, key: mx.array, value: mx.array, g: mx.array, beta: mx.array, - initial_state: Optional[mx.array] = None, output_final_state: bool = False, - use_qk_l2norm_in_kernel: bool = False + query: mx.array, + key: mx.array, + value: mx.array, + g: mx.array, + beta: mx.array, + initial_state: Optional[mx.array] = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, ) -> Tuple[mx.array, Optional[mx.array]]: + """Minimal recurrent gated delta rule in MLX matching the Torch reference. + Expects query/key/value shapes (B, S, H, D*) and g/beta shapes (B, S, H) or (B, S, H, 1). + """ + # Optional L2 normalization on last dim for query/key if use_qk_l2norm_in_kernel: - query /= mx.linalg.norm(query, axis=-1, keepdims=True) - key /= mx.linalg.norm(key, axis=-1, keepdims=True) - query, key, value, beta, g = [mx.transpose(x, (0,2,1,3)).astype(mx.float32) for x in (query,key,value,beta,g)] - B,H,T,Dk = key.shape; Dv = value.shape[-1]; scale = 1.0 / (query.shape[-1]**0.5) - query *= scale - state = mx.zeros((B,H,Dk,Dv), dtype=mx.float32) if initial_state is None else initial_state.astype(mx.float32).reshape(B,H,Dk,Dv) - outs = [] - for t in range(T): - state = state * mx.expand_dims(mx.exp(g[:, :, t, :]), -1) - mem = mx.einsum("bhkv,bhk->bhv", state, key[:,:,t]) - delta = (value[:,:,t] - mem) * beta[:,:,t] - state += mx.einsum("bhk,bhv->bhkv", key[:,:,t], delta) - outs.append(mx.einsum("bhkv,bhk->bhv", state, query[:,:,t])) - out = mx.transpose(mx.stack(outs, axis=2), (0,2,1,3)).astype(query.dtype) - return out, (state if output_final_state else None) + # Normalize along the feature dimension + query = query / mx.maximum(mx.linalg.norm(query, axis=-1, keepdims=True), 1e-12) + key = key / mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) + + # Cast to float32 for numerical stability (like Torch .to(torch.float32)) + query = query.astype(mx.float32) + key = key.astype(mx.float32) + value = value.astype(mx.float32) + beta = beta.astype(mx.float32) + g = g.astype(mx.float32) + + # Allow beta and g to come with an extra trailing singleton dim: (B,S,H,1) + if beta.ndim == 4 and beta.shape[-1] == 1: + beta = beta.squeeze(-1) + if g.ndim == 4 and g.shape[-1] == 1: + g = g.squeeze(-1) + + B, S, H, Dk = key.shape + Dv = value.shape[-1] + + # Scale queries by 1/sqrt(Dq) (Dq == last dim of query) + scale = 1.0 / mx.sqrt(mx.array(query.shape[-1], dtype=mx.float32)) + query = query * scale + + # Precompute value*beta and key*beta to match the Torch reference + v_beta = value * beta[..., None] + k_beta = key * beta[..., None] + + # Initialize state: (B, H, Dk, Dv) + if initial_state is None: + state = mx.zeros((B, H, Dk, Dv), dtype=value.dtype) + else: + state = initial_state.astype(value.dtype) + if state.shape != (B, H, Dk, Dv): + state = state.reshape(B, H, Dk, Dv) + + # Output buffer: (B, S, H, Dv) + out = mx.zeros((B, S, H, Dv), dtype=value.dtype) + + for t in range(S): + q_t = query[:, t] # (B, H, Dk) + k_t = k_beta[:, t] # (B, H, Dk) + v_t = v_beta[:, t] # (B, H, Dv) + g_t = g[:, t] # (B, H) + + # decay = exp(g_t) + decay = mx.exp(g_t)[..., None] # (B, H, 1) + + # state = state * decay.unsqueeze(-1) + k_t.unsqueeze(-1) @ v_t.unsqueeze(-2) + state = state * decay[..., None] + mx.matmul( + k_t[..., None], # (B, H, Dk, 1) + v_t[..., None, :], # (B, H, 1, Dv) + ) + + # out[:, t] = einsum("bhd,bhdv->bhv", q_t, state) + out[:, t] = mx.einsum("bhd,bhdv->bhv", q_t, state) + + # Return (B, H, S, Dv) like Torch's out.transpose(1, 2) + out = mx.transpose(out, (0, 2, 1, 3)).astype(query.dtype) + + if not output_final_state: + state = None + return out, state class Qwen3NextRMSNormGated(nn.Module): From 3864198614c464cf3a241e04598b21a889d3f689 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:58:20 +0200 Subject: [PATCH 20/44] closer --- mlx_lm/models/qwen3_next.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 3d98417fa..aab904d63 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -243,6 +243,18 @@ def __init__(self, config: ModelArgs): self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + def apply_mask_to_padding_states(self, hidden_states, attention_mask): + """ + Applies the mask to hidden_states like Torch reference. + If attention_mask is not None and has more than 1 row and column, multiply and cast. + """ + if attention_mask is not None: + if attention_mask.shape[0] > 1 and attention_mask.shape[1] > 1: + dtype = hidden_states.dtype + # Mask shape: (B, S), hidden_states: (B, S, ...) + hidden_states = (hidden_states * attention_mask[:, :, None]).astype(dtype) + return hidden_states + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): nq, nk, nv, dv = self.num_k_heads, self.head_k_dim, self.num_v_heads, self.head_v_dim mixed_qkvz = mixed_qkvz.reshape(*mixed_qkvz.shape[:-1], nq, 2*nk + 2*nv*dv//nq) @@ -258,8 +270,15 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): a.reshape(a.shape[0], a.shape[1], nv), ) - def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): + # Apply attention mask to inputs, if provided and valid + if mask is not None: + # mask shape (B, 1, 1, S) or similar + mask_2d = mask.squeeze(1).squeeze(1) + else: + mask_2d = None + inputs = self.apply_mask_to_padding_states(inputs, mask_2d) + B,L,_ = inputs.shape qkvz, ba = self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) q,k,v,z,b,a = self.fix_query_key_value_ordering(qkvz, ba) From 48f52229ab7e45bc79ff184c69bf010575500126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:29:19 +0200 Subject: [PATCH 21/44] upd. --- mlx_lm/models/qwen3_next.py | 202 +++++++++++++++++++++++------------- 1 file changed, 130 insertions(+), 72 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index aab904d63..03d668a12 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -7,10 +7,16 @@ import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .cache import KVCache, MambaCache +from .rope_utils import initialize_rope +from .cache import KVCache, ArraysCache from .switch_layers import SwitchGLU +class MambaCache(ArraysCache): + def __init__(self): + super().__init__(size=2) + + @dataclass class ModelArgs(BaseModelArgs): model_type: str @@ -55,6 +61,7 @@ def recurrent_gated_delta_rule( """Minimal recurrent gated delta rule in MLX matching the Torch reference. Expects query/key/value shapes (B, S, H, D*) and g/beta shapes (B, S, H) or (B, S, H, 1). """ + orig_dtype = query.dtype # Optional L2 normalization on last dim for query/key if use_qk_l2norm_in_kernel: # Normalize along the feature dimension @@ -115,7 +122,7 @@ def recurrent_gated_delta_rule( out[:, t] = mx.einsum("bhd,bhdv->bhv", q_t, state) # Return (B, H, S, Dv) like Torch's out.transpose(1, 2) - out = mx.transpose(out, (0, 2, 1, 3)).astype(query.dtype) + out = mx.transpose(out, (0, 2, 1, 3)).astype(orig_dtype) if not output_final_state: state = None @@ -131,7 +138,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): def __call__(self, hidden_states: mx.array, gate: mx.array = None) -> mx.array: if gate is not None: hidden_states = hidden_states * nn.silu(gate) - return mx.fast.rms_norm(hidden_states, self.weight, self.eps) + return mx.fast.rms_norm(hidden_states, 1.0 + self.weight, self.eps) class Qwen3NextAttention(nn.Module): @@ -143,17 +150,19 @@ def __init__(self, args: ModelArgs): self.scale = self.head_dim**-0.5 self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * self.head_dim * 2, bias=args.attention_bias) - self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, args.hidden_size, bias=False) + self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=args.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, args.hidden_size, bias=args.attention_bias) self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) - self.rope = nn.RoPE( + self.rope = initialize_rope( self.head_dim, - traditional=False, base=args.rope_theta, + traditional=False, + scaling_config=args.rope_scaling, + max_position_embeddings=args.max_position_embeddings, ) def __call__( @@ -219,14 +228,13 @@ def __init__(self, config: ModelArgs): self.layer_norm_epsilon = config.rms_norm_eps self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = nn.Conv1d( in_channels=self.conv_dim, out_channels=self.conv_dim, bias=False, kernel_size=self.conv_kernel_size, groups=self.conv_dim, - padding=0, + padding=self.conv_kernel_size - 1, ) projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 @@ -244,10 +252,6 @@ def __init__(self, config: ModelArgs): self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) def apply_mask_to_padding_states(self, hidden_states, attention_mask): - """ - Applies the mask to hidden_states like Torch reference. - If attention_mask is not None and has more than 1 row and column, multiply and cast. - """ if attention_mask is not None: if attention_mask.shape[0] > 1 and attention_mask.shape[1] > 1: dtype = hidden_states.dtype @@ -271,49 +275,96 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): ) def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): - # Apply attention mask to inputs, if provided and valid - if mask is not None: - # mask shape (B, 1, 1, S) or similar - mask_2d = mask.squeeze(1).squeeze(1) - else: - mask_2d = None + # Derive a 2D (B, S) mask for padding if an array was provided; ignore string sentinels like 'causal'. + mask_2d = None + if isinstance(mask, mx.array): + if mask.ndim >= 4: + mask_2d = mask.squeeze(1).squeeze(1) + elif mask.ndim == 2: + mask_2d = mask + inputs = self.apply_mask_to_padding_states(inputs, mask_2d) - B,L,_ = inputs.shape - qkvz, ba = self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) - q,k,v,z,b,a = self.fix_query_key_value_ordering(qkvz, ba) - q,k,v = (x.reshape(B,L,-1) for x in (q,k,v)) - mixed_qkv = mx.concatenate((q,k,v),-1) + B, S, _ = inputs.shape + # Split cache into conv_state and recurrent_state if provided + if cache is not None: + conv_state, recurrent_state = cache + else: + conv_state = None + recurrent_state = None + + # Project to QKVZ and BA then fix ordering + qkvz = self.in_proj_qkvz(inputs) + ba = self.in_proj_ba(inputs) + q, k, v, z, b, a = self.fix_query_key_value_ordering(qkvz, ba) + + # Reshape q, k, v to (B, S, -1) + q = q.reshape(B, S, -1) + k = k.reshape(B, S, -1) + v = v.reshape(B, S, -1) + # Concatenate for conv1d + x_qkv = mx.concatenate([q, k, v], axis=-1) + + # Convolutional state/caching logic if cache is not None: - conv_state, rec_state = cache + # OG Torch: if conv_state is None, allocate zeros if conv_state is None: - conv_state = mx.zeros((B,self.conv_kernel_size-1,self.conv_dim),dtype=inputs.dtype) - padded = mx.concatenate([conv_state,mixed_qkv],1) - cache[0] = padded[:,-(self.conv_kernel_size-1):] - mixed_qkv = nn.silu(self.conv1d(padded)[:,:L]) + conv_state = mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype) + # If S == 1, use causal conv update (append new token to cache, run conv1d on last window) + if S == 1: + # "Causal conv1d update" -- append new x_qkv to conv_state and apply conv1d to the window + padded = mx.concatenate([conv_state, x_qkv], axis=1) + # Keep last (kernel_size-1) for next time + cache[0] = padded[:, -(self.conv_kernel_size - 1):] + conv_out = self.conv1d(padded)[:, -S:] # take only the last position + conv_out = nn.silu(conv_out) + else: + # For sequence, pad and apply conv1d + padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) + conv_out = self.conv1d(padded)[:, :S] + conv_out = nn.silu(conv_out) + cache[0] = x_qkv[:, -(self.conv_kernel_size - 1):] # update cache else: - padded = mx.pad(mixed_qkv,[(0,0),(self.conv_kernel_size-1,0),(0,0)]) - mixed_qkv = nn.silu(self.conv1d(padded)[:,:L]); rec_state=None - - q,k,v = mx.split(mixed_qkv,[self.key_dim,2*self.key_dim],-1) - q = q.reshape(B,L,-1,self.head_k_dim); k = k.reshape(B,L,-1,self.head_k_dim); v = v.reshape(B,L,-1,self.head_v_dim) - - beta = mx.sigmoid(b).reshape(B,L,-1,1) - g = (-mx.exp(self.A_log.astype(mx.float32))*nn.softplus(a.astype(mx.float32)+self.dt_bias)).reshape(B,L,-1,1) - if self.num_v_heads//self.num_k_heads>1: - q = mx.repeat(q,self.num_v_heads//self.num_k_heads,axis=2) - k = mx.repeat(k,self.num_v_heads//self.num_k_heads,axis=2) - - if rec_state is None: - rec_state = mx.zeros((B,self.num_v_heads,self.head_k_dim,self.head_v_dim),dtype=inputs.dtype) - out,new_state = recurrent_gated_delta_rule(q,k,v,g,beta,rec_state,cache is not None,True) - - if cache is not None: cache[1]=new_state - else: new_state=None + # No cache: pad and apply conv1d + padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) + conv_out = self.conv1d(padded)[:, :S] + conv_out = nn.silu(conv_out) + recurrent_state = None + + # Split conv_out back to q, k, v and reshape to (B, S, heads, head_dim) + q_c, k_c, v_c = mx.split(conv_out, [self.key_dim, 2 * self.key_dim], axis=-1) + q = q_c.reshape(B, S, self.num_k_heads, self.head_k_dim) + k = k_c.reshape(B, S, self.num_k_heads, self.head_k_dim) + v = v_c.reshape(B, S, self.num_v_heads, self.head_v_dim) + + # beta = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias) + beta = mx.sigmoid(b) + g = -mx.exp(self.A_log.astype(mx.float32)) * nn.softplus(a.astype(mx.float32) + self.dt_bias) + # No .reshape(...,1): keep as (B,S,H) + + # If num_v_heads > num_k_heads, repeat q/k accordingly (along the heads axis) + if self.num_v_heads // self.num_k_heads > 1: + repeat_factor = self.num_v_heads // self.num_k_heads + q = mx.repeat(q, repeat_factor, axis=2) + k = mx.repeat(k, repeat_factor, axis=2) + + # Choose the recurrent rule: if step size is 1 and recurrent_state exists, use recurrent_gated_delta_rule, else same (no chunked variant in MLX) + if recurrent_state is None: + recurrent_state = mx.zeros((B, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=inputs.dtype) + # (B, S, H, Dk), (B, S, H, Dk), (B, S, H, Dv), (B, S, H), (B, S, H) + out, new_state = recurrent_gated_delta_rule( + q, k, v, g, beta, recurrent_state, output_final_state=True, use_qk_l2norm_in_kernel=True + ) + # Update cache recurrent_state if cache exists + if cache is not None: + cache[1] = new_state - out = self.norm(out.reshape(-1,out.shape[-1]),z.reshape(-1,z.shape[-1])).reshape(z.shape[0],z.shape[1],-1) - return self.out_proj(out) + # Apply norm on (B*S, dim), then reshape back (B, S, -1) + out_reshaped = out.reshape(-1, out.shape[-1]) + z_reshaped = z.reshape(-1, z.shape[-1]) + core_out = self.norm(out_reshaped, z_reshaped).reshape(z.shape[0], z.shape[1], -1) + return self.out_proj(core_out) class Qwen3NextSparseMoeBlock(nn.Module): @@ -387,8 +438,7 @@ def __call__( elif self.layer_type == "full_attention": r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r + out = h + self.mlp(self.post_attention_layernorm(h)) return out @@ -412,25 +462,14 @@ def __call__( if cache is None: cache = [None] * len(self.layers) - cache_counter = 0 - for layer in self.layers: - if layer.layer_type == "linear_attention": - c = cache[cache_counter] - cache_counter += 1 - elif layer.layer_type == "full_attention": - c = cache[cache_counter] - cache_counter += 1 - else: - c = None + causal_mask = create_attention_mask(hidden_states, cache) + linear_mask = None + if cache and len(cache) > 0: + linear_mask = cache[0] - # Compute attention mask per layer as needed - if layer.layer_type == "full_attention": - mask_to_use = create_attention_mask(hidden_states, c) - elif layer.layer_type == "linear_attention": - mask_to_use = None - else: - mask_to_use = None - hidden_states = layer(hidden_states, mask=mask_to_use, cache=c) + for i, layer in enumerate(self.layers): + layer_mask = linear_mask if getattr(layer, "layer_type", None) == "linear_attention" else causal_mask + hidden_states = layer(hidden_states, mask=layer_mask, cache=cache[i]) return self.norm(hidden_states) @@ -441,7 +480,8 @@ def __init__(self, args: ModelArgs): self.args = args self.model_type = args.model_type self.model = Qwen3NextModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, @@ -449,7 +489,11 @@ def __call__( cache: Optional[Any] = None, ): out = self.model(inputs, cache) - return self.lm_head(out) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out @property def layers(self): @@ -479,4 +523,18 @@ def sanitize(self, weights): for k, v in weights.items(): if "conv1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) + if self.args.tie_word_embeddings: + weights.pop("lm_head.weight", None) + norm_keys = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + ".norm.weight", + ".q_norm.weight", + ".k_norm.weight", + ) + for k in list(weights.keys()): + if any(sfx in k for sfx in norm_keys): + v = weights[k] + if len(v.shape) == 1: + weights[k] = v + 1.0 return weights \ No newline at end of file From 7bf6f8aa53dab1c82b290215ef00ecbefcf533f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:37:09 +0200 Subject: [PATCH 22/44] fix inference --- mlx_lm/models/qwen3_next.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 03d668a12..1c631114c 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -135,10 +135,29 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): self.eps = eps self.weight = mx.ones(hidden_size) - def __call__(self, hidden_states: mx.array, gate: mx.array = None) -> mx.array: + def __call__(self, hidden_states: mx.array, gate: Optional[mx.array] = None) -> mx.array: + input_dtype = hidden_states.dtype + x = hidden_states.astype(mx.float32) + variance = mx.mean(mx.square(x), axis=-1, keepdims=True) + x = x * mx.rsqrt(variance + self.eps) + x = (1.0 + self.weight.astype(mx.float32)) * x if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - return mx.fast.rms_norm(hidden_states, 1.0 + self.weight, self.eps) + x = x * nn.silu(gate.astype(mx.float32)) + return x.astype(input_dtype) + + +class Qwen3NextRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = mx.zeros(dim) + + def __call__(self, x: mx.array) -> mx.array: + x = x.astype(mx.float32) + variance = mx.mean(mx.square(x), axis=-1, keepdims=True) + output = x * mx.rsqrt(variance + self.eps) + output = output * (1.0 + self.weight.astype(mx.float32)) + return output.astype(x.dtype) class Qwen3NextAttention(nn.Module): @@ -154,8 +173,8 @@ def __init__(self, args: ModelArgs): self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=args.attention_bias) self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, args.hidden_size, bias=args.attention_bias) - self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) - self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=args.rms_norm_eps) self.rope = initialize_rope( self.head_dim, @@ -414,8 +433,8 @@ def __init__(self, args: ModelArgs, layer_idx: int): elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention(args) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( + self.input_layernorm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( args.hidden_size, eps=args.rms_norm_eps ) self.args = args @@ -450,7 +469,7 @@ def __init__(self, args: ModelArgs): Qwen3NextDecoderLayer(args=args, layer_idx=i) for i in range(args.num_hidden_layers) ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, From daf6f0bd275caac1f64cdcb65875ffe233e35da5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:47:23 +0200 Subject: [PATCH 23/44] fix --- mlx_lm/models/qwen3_next.py | 94 +++++++------------------------------ 1 file changed, 18 insertions(+), 76 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 1c631114c..4b5b13d61 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -8,15 +8,10 @@ from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .rope_utils import initialize_rope -from .cache import KVCache, ArraysCache +from .cache import KVCache, MambaCache from .switch_layers import SwitchGLU -class MambaCache(ArraysCache): - def __init__(self): - super().__init__(size=2) - - @dataclass class ModelArgs(BaseModelArgs): model_type: str @@ -55,7 +50,6 @@ def recurrent_gated_delta_rule( g: mx.array, beta: mx.array, initial_state: Optional[mx.array] = None, - output_final_state: bool = False, use_qk_l2norm_in_kernel: bool = False, ) -> Tuple[mx.array, Optional[mx.array]]: """Minimal recurrent gated delta rule in MLX matching the Torch reference. @@ -124,9 +118,7 @@ def recurrent_gated_delta_rule( # Return (B, H, S, Dv) like Torch's out.transpose(1, 2) out = mx.transpose(out, (0, 2, 1, 3)).astype(orig_dtype) - if not output_final_state: - state = None - return out, state + return out class Qwen3NextRMSNormGated(nn.Module): @@ -135,15 +127,11 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): self.eps = eps self.weight = mx.ones(hidden_size) - def __call__(self, hidden_states: mx.array, gate: Optional[mx.array] = None) -> mx.array: - input_dtype = hidden_states.dtype - x = hidden_states.astype(mx.float32) - variance = mx.mean(mx.square(x), axis=-1, keepdims=True) - x = x * mx.rsqrt(variance + self.eps) - x = (1.0 + self.weight.astype(mx.float32)) * x + def __call__(self, hidden_states: mx.array, gate: mx.array | None = None) -> mx.array: + x = mx.fast.rms_norm(hidden_states, None, self.eps) * (1.0 + self.weight) if gate is not None: - x = x * nn.silu(gate.astype(mx.float32)) - return x.astype(input_dtype) + x = x * nn.silu(gate) + return x class Qwen3NextRMSNorm(nn.Module): @@ -151,13 +139,9 @@ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = mx.zeros(dim) - + def __call__(self, x: mx.array) -> mx.array: - x = x.astype(mx.float32) - variance = mx.mean(mx.square(x), axis=-1, keepdims=True) - output = x * mx.rsqrt(variance + self.eps) - output = output * (1.0 + self.weight.astype(mx.float32)) - return output.astype(x.dtype) + return mx.fast.rms_norm(x, None, self.eps) * (1.0 + self.weight) class Qwen3NextAttention(nn.Module): @@ -256,10 +240,8 @@ def __init__(self, config: ModelArgs): padding=self.conv_kernel_size - 1, ) - projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - projection_size_ba = self.num_v_heads * 2 - self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False) - self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False) + self.in_proj_qkvz = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim * 2, bias=False) + self.in_proj_ba = nn.Linear(self.hidden_size, self.num_v_heads * 2, bias=False) self.dt_bias = mx.ones(self.num_v_heads) @@ -270,14 +252,6 @@ def __init__(self, config: ModelArgs): self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) - def apply_mask_to_padding_states(self, hidden_states, attention_mask): - if attention_mask is not None: - if attention_mask.shape[0] > 1 and attention_mask.shape[1] > 1: - dtype = hidden_states.dtype - # Mask shape: (B, S), hidden_states: (B, S, ...) - hidden_states = (hidden_states * attention_mask[:, :, None]).astype(dtype) - return hidden_states - def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): nq, nk, nv, dv = self.num_k_heads, self.head_k_dim, self.num_v_heads, self.head_v_dim mixed_qkvz = mixed_qkvz.reshape(*mixed_qkvz.shape[:-1], nq, 2*nk + 2*nv*dv//nq) @@ -294,16 +268,6 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): ) def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): - # Derive a 2D (B, S) mask for padding if an array was provided; ignore string sentinels like 'causal'. - mask_2d = None - if isinstance(mask, mx.array): - if mask.ndim >= 4: - mask_2d = mask.squeeze(1).squeeze(1) - elif mask.ndim == 2: - mask_2d = mask - - inputs = self.apply_mask_to_padding_states(inputs, mask_2d) - B, S, _ = inputs.shape # Split cache into conv_state and recurrent_state if provided @@ -327,25 +291,11 @@ def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Opt # Convolutional state/caching logic if cache is not None: - # OG Torch: if conv_state is None, allocate zeros - if conv_state is None: - conv_state = mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype) - # If S == 1, use causal conv update (append new token to cache, run conv1d on last window) - if S == 1: - # "Causal conv1d update" -- append new x_qkv to conv_state and apply conv1d to the window - padded = mx.concatenate([conv_state, x_qkv], axis=1) - # Keep last (kernel_size-1) for next time - cache[0] = padded[:, -(self.conv_kernel_size - 1):] - conv_out = self.conv1d(padded)[:, -S:] # take only the last position - conv_out = nn.silu(conv_out) - else: - # For sequence, pad and apply conv1d - padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) - conv_out = self.conv1d(padded)[:, :S] - conv_out = nn.silu(conv_out) - cache[0] = x_qkv[:, -(self.conv_kernel_size - 1):] # update cache + padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) + conv_out = self.conv1d(padded)[:, :S] + conv_out = nn.silu(conv_out) + cache[0] = x_qkv[:, -(self.conv_kernel_size - 1):] else: - # No cache: pad and apply conv1d padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) conv_out = self.conv1d(padded)[:, :S] conv_out = nn.silu(conv_out) @@ -359,7 +309,7 @@ def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Opt # beta = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias) beta = mx.sigmoid(b) - g = -mx.exp(self.A_log.astype(mx.float32)) * nn.softplus(a.astype(mx.float32) + self.dt_bias) + g = -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) # No .reshape(...,1): keep as (B,S,H) # If num_v_heads > num_k_heads, repeat q/k accordingly (along the heads axis) @@ -372,13 +322,9 @@ def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Opt if recurrent_state is None: recurrent_state = mx.zeros((B, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=inputs.dtype) # (B, S, H, Dk), (B, S, H, Dk), (B, S, H, Dv), (B, S, H), (B, S, H) - out, new_state = recurrent_gated_delta_rule( - q, k, v, g, beta, recurrent_state, output_final_state=True, use_qk_l2norm_in_kernel=True + out = recurrent_gated_delta_rule( + q, k, v, g, beta, recurrent_state, use_qk_l2norm_in_kernel=True ) - # Update cache recurrent_state if cache exists - if cache is not None: - cache[1] = new_state - # Apply norm on (B*S, dim), then reshape back (B, S, -1) out_reshaped = out.reshape(-1, out.shape[-1]) z_reshaped = z.reshape(-1, z.shape[-1]) @@ -482,13 +428,9 @@ def __call__( cache = [None] * len(self.layers) causal_mask = create_attention_mask(hidden_states, cache) - linear_mask = None - if cache and len(cache) > 0: - linear_mask = cache[0] for i, layer in enumerate(self.layers): - layer_mask = linear_mask if getattr(layer, "layer_type", None) == "linear_attention" else causal_mask - hidden_states = layer(hidden_states, mask=layer_mask, cache=cache[i]) + hidden_states = layer(hidden_states, mask=causal_mask, cache=cache[i]) return self.norm(hidden_states) From 1d952a45e0504a25bb647a6768310374c3015b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:05:23 +0200 Subject: [PATCH 24/44] optimization --- mlx_lm/models/qwen3_next.py | 103 +++++++++--------------------------- 1 file changed, 25 insertions(+), 78 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 4b5b13d61..1e5eec9f7 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -8,10 +8,15 @@ from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .rope_utils import initialize_rope -from .cache import KVCache, MambaCache +from .cache import KVCache, ArraysCache from .switch_layers import SwitchGLU +class MambaCache(ArraysCache): + def __init__(self): + super().__init__(size=2) + + @dataclass class ModelArgs(BaseModelArgs): model_type: str @@ -51,74 +56,30 @@ def recurrent_gated_delta_rule( beta: mx.array, initial_state: Optional[mx.array] = None, use_qk_l2norm_in_kernel: bool = False, -) -> Tuple[mx.array, Optional[mx.array]]: - """Minimal recurrent gated delta rule in MLX matching the Torch reference. - Expects query/key/value shapes (B, S, H, D*) and g/beta shapes (B, S, H) or (B, S, H, 1). - """ - orig_dtype = query.dtype - # Optional L2 normalization on last dim for query/key +): + """Minimal recurrent gated delta rule in MLX, same inputs/outputs as Torch.""" if use_qk_l2norm_in_kernel: - # Normalize along the feature dimension - query = query / mx.maximum(mx.linalg.norm(query, axis=-1, keepdims=True), 1e-12) - key = key / mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) - - # Cast to float32 for numerical stability (like Torch .to(torch.float32)) - query = query.astype(mx.float32) - key = key.astype(mx.float32) - value = value.astype(mx.float32) - beta = beta.astype(mx.float32) - g = g.astype(mx.float32) - - # Allow beta and g to come with an extra trailing singleton dim: (B,S,H,1) - if beta.ndim == 4 and beta.shape[-1] == 1: - beta = beta.squeeze(-1) - if g.ndim == 4 and g.shape[-1] == 1: - g = g.squeeze(-1) + query /= mx.maximum(mx.linalg.norm(query, axis=-1, keepdims=True), 1e-12) + key /= mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) + + if beta.ndim == 4: beta = beta.squeeze(-1) + if g.ndim == 4: g = g.squeeze(-1) B, S, H, Dk = key.shape Dv = value.shape[-1] + query *= 1.0 / mx.sqrt(mx.array(query.shape[-1])) - # Scale queries by 1/sqrt(Dq) (Dq == last dim of query) - scale = 1.0 / mx.sqrt(mx.array(query.shape[-1], dtype=mx.float32)) - query = query * scale - - # Precompute value*beta and key*beta to match the Torch reference - v_beta = value * beta[..., None] - k_beta = key * beta[..., None] - - # Initialize state: (B, H, Dk, Dv) - if initial_state is None: - state = mx.zeros((B, H, Dk, Dv), dtype=value.dtype) - else: - state = initial_state.astype(value.dtype) - if state.shape != (B, H, Dk, Dv): - state = state.reshape(B, H, Dk, Dv) - - # Output buffer: (B, S, H, Dv) - out = mx.zeros((B, S, H, Dv), dtype=value.dtype) + v_beta, k_beta = value * beta[..., None], key * beta[..., None] + state = mx.zeros((B, H, Dk, Dv), dtype=value.dtype) if initial_state is None else initial_state.reshape(B, H, Dk, Dv) + out = mx.zeros((B, S, H, Dv), dtype=value.dtype) for t in range(S): - q_t = query[:, t] # (B, H, Dk) - k_t = k_beta[:, t] # (B, H, Dk) - v_t = v_beta[:, t] # (B, H, Dv) - g_t = g[:, t] # (B, H) - - # decay = exp(g_t) - decay = mx.exp(g_t)[..., None] # (B, H, 1) - - # state = state * decay.unsqueeze(-1) + k_t.unsqueeze(-1) @ v_t.unsqueeze(-2) - state = state * decay[..., None] + mx.matmul( - k_t[..., None], # (B, H, Dk, 1) - v_t[..., None, :], # (B, H, 1, Dv) + state = state * mx.exp(g[:, t])[..., None, None] + mx.matmul( + k_beta[:, t][..., None], v_beta[:, t][..., None, :] ) + out[:, t] = mx.einsum("bhd,bhdv->bhv", query[:, t], state) - # out[:, t] = einsum("bhd,bhdv->bhv", q_t, state) - out[:, t] = mx.einsum("bhd,bhdv->bhv", q_t, state) - - # Return (B, H, S, Dv) like Torch's out.transpose(1, 2) - out = mx.transpose(out, (0, 2, 1, 3)).astype(orig_dtype) - - return out + return mx.transpose(out, (0, 2, 1, 3)) class Qwen3NextRMSNormGated(nn.Module): @@ -269,14 +230,6 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): B, S, _ = inputs.shape - - # Split cache into conv_state and recurrent_state if provided - if cache is not None: - conv_state, recurrent_state = cache - else: - conv_state = None - recurrent_state = None - # Project to QKVZ and BA then fix ordering qkvz = self.in_proj_qkvz(inputs) ba = self.in_proj_ba(inputs) @@ -290,16 +243,10 @@ def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Opt x_qkv = mx.concatenate([q, k, v], axis=-1) # Convolutional state/caching logic - if cache is not None: - padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) - conv_out = self.conv1d(padded)[:, :S] - conv_out = nn.silu(conv_out) - cache[0] = x_qkv[:, -(self.conv_kernel_size - 1):] - else: - padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) - conv_out = self.conv1d(padded)[:, :S] - conv_out = nn.silu(conv_out) - recurrent_state = None + padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) + conv_out = self.conv1d(padded)[:, :S] + conv_out = nn.silu(conv_out) + recurrent_state = None # Split conv_out back to q, k, v and reshape to (B, S, heads, head_dim) q_c, k_c, v_c = mx.split(conv_out, [self.key_dim, 2 * self.key_dim], axis=-1) From 21afa6005ca757c642d6cbd10a39cfb7da8e94d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:06:22 +0200 Subject: [PATCH 25/44] nits --- mlx_lm/models/qwen3_next.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 1e5eec9f7..0dc2264a0 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -57,7 +57,6 @@ def recurrent_gated_delta_rule( initial_state: Optional[mx.array] = None, use_qk_l2norm_in_kernel: bool = False, ): - """Minimal recurrent gated delta rule in MLX, same inputs/outputs as Torch.""" if use_qk_l2norm_in_kernel: query /= mx.maximum(mx.linalg.norm(query, axis=-1, keepdims=True), 1e-12) key /= mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) From 1d078115615615cfc47da847c69e485c68c107e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:23:13 +0200 Subject: [PATCH 26/44] minimize --- mlx_lm/models/qwen3_next.py | 69 +++++++++++++------------------------ 1 file changed, 24 insertions(+), 45 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 0dc2264a0..fbed88cde 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -229,53 +229,32 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): B, S, _ = inputs.shape - # Project to QKVZ and BA then fix ordering - qkvz = self.in_proj_qkvz(inputs) - ba = self.in_proj_ba(inputs) - q, k, v, z, b, a = self.fix_query_key_value_ordering(qkvz, ba) - - # Reshape q, k, v to (B, S, -1) - q = q.reshape(B, S, -1) - k = k.reshape(B, S, -1) - v = v.reshape(B, S, -1) - # Concatenate for conv1d - x_qkv = mx.concatenate([q, k, v], axis=-1) - - # Convolutional state/caching logic - padded = mx.pad(x_qkv, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]) - conv_out = self.conv1d(padded)[:, :S] - conv_out = nn.silu(conv_out) - recurrent_state = None - - # Split conv_out back to q, k, v and reshape to (B, S, heads, head_dim) - q_c, k_c, v_c = mx.split(conv_out, [self.key_dim, 2 * self.key_dim], axis=-1) - q = q_c.reshape(B, S, self.num_k_heads, self.head_k_dim) - k = k_c.reshape(B, S, self.num_k_heads, self.head_k_dim) - v = v_c.reshape(B, S, self.num_v_heads, self.head_v_dim) - - # beta = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias) - beta = mx.sigmoid(b) - g = -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) - # No .reshape(...,1): keep as (B,S,H) - - # If num_v_heads > num_k_heads, repeat q/k accordingly (along the heads axis) - if self.num_v_heads // self.num_k_heads > 1: - repeat_factor = self.num_v_heads // self.num_k_heads - q = mx.repeat(q, repeat_factor, axis=2) - k = mx.repeat(k, repeat_factor, axis=2) - - # Choose the recurrent rule: if step size is 1 and recurrent_state exists, use recurrent_gated_delta_rule, else same (no chunked variant in MLX) - if recurrent_state is None: - recurrent_state = mx.zeros((B, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=inputs.dtype) - # (B, S, H, Dk), (B, S, H, Dk), (B, S, H, Dv), (B, S, H), (B, S, H) + q, k, v, z, b, a = self.fix_query_key_value_ordering(self.in_proj_qkvz(inputs), self.in_proj_ba(inputs)) + + # Conv1d on concatenated q/k/v + conv_out = nn.silu(self.conv1d(mx.pad(mx.concatenate( + [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], -1 + ), [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]))[:, :S]) + + q, k, v = [t.reshape(B, S, h, d) for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim] + )] + + beta, g = mx.sigmoid(b), -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) + + if self.num_v_heads > self.num_k_heads: + f = self.num_v_heads // self.num_k_heads + q, k = mx.repeat(q, f, 2), mx.repeat(k, f, 2) + out = recurrent_gated_delta_rule( - q, k, v, g, beta, recurrent_state, use_qk_l2norm_in_kernel=True + q, k, v, g, beta, + mx.zeros((B, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=inputs.dtype), + use_qk_l2norm_in_kernel=True ) - # Apply norm on (B*S, dim), then reshape back (B, S, -1) - out_reshaped = out.reshape(-1, out.shape[-1]) - z_reshaped = z.reshape(-1, z.shape[-1]) - core_out = self.norm(out_reshaped, z_reshaped).reshape(z.shape[0], z.shape[1], -1) - return self.out_proj(core_out) + out = self.norm(out.reshape(-1, out.shape[-1]), z.reshape(-1, z.shape[-1])) + return self.out_proj(out.reshape(B, S, -1)) class Qwen3NextSparseMoeBlock(nn.Module): From 12560cde46ae575480abc73f64c07bf3e27bdd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:24:04 +0200 Subject: [PATCH 27/44] clean ups --- mlx_lm/models/qwen3_next.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index fbed88cde..beef8abec 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -160,9 +160,7 @@ def __call__( ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - output = output * mx.sigmoid(gate) - - return self.o_proj(output) + return self.o_proj(output * mx.sigmoid(gate)) class Qwen3NextMLP(nn.Module): From 65f4250663bc62258a6cf603cf0223d686efa734 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:24:18 +0200 Subject: [PATCH 28/44] format --- mlx_lm/models/qwen3_next.py | 158 +++++++++++++++++++++++++----------- 1 file changed, 111 insertions(+), 47 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index beef8abec..89e66d83c 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -1,14 +1,14 @@ # Copyright © 2025 Apple Inc. from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import ArraysCache, KVCache from .rope_utils import initialize_rope -from .cache import KVCache, ArraysCache from .switch_layers import SwitchGLU @@ -59,18 +59,24 @@ def recurrent_gated_delta_rule( ): if use_qk_l2norm_in_kernel: query /= mx.maximum(mx.linalg.norm(query, axis=-1, keepdims=True), 1e-12) - key /= mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) + key /= mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) - if beta.ndim == 4: beta = beta.squeeze(-1) - if g.ndim == 4: g = g.squeeze(-1) + if beta.ndim == 4: + beta = beta.squeeze(-1) + if g.ndim == 4: + g = g.squeeze(-1) B, S, H, Dk = key.shape Dv = value.shape[-1] query *= 1.0 / mx.sqrt(mx.array(query.shape[-1])) v_beta, k_beta = value * beta[..., None], key * beta[..., None] - state = mx.zeros((B, H, Dk, Dv), dtype=value.dtype) if initial_state is None else initial_state.reshape(B, H, Dk, Dv) - out = mx.zeros((B, S, H, Dv), dtype=value.dtype) + state = ( + mx.zeros((B, H, Dk, Dv), dtype=value.dtype) + if initial_state is None + else initial_state.reshape(B, H, Dk, Dv) + ) + out = mx.zeros((B, S, H, Dv), dtype=value.dtype) for t in range(S): state = state * mx.exp(g[:, t])[..., None, None] + mx.matmul( @@ -87,7 +93,9 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): self.eps = eps self.weight = mx.ones(hidden_size) - def __call__(self, hidden_states: mx.array, gate: mx.array | None = None) -> mx.array: + def __call__( + self, hidden_states: mx.array, gate: mx.array | None = None + ) -> mx.array: x = mx.fast.rms_norm(hidden_states, None, self.eps) * (1.0 + self.weight) if gate is not None: x = x * nn.silu(gate) @@ -112,10 +120,26 @@ def __init__(self, args: ModelArgs): self.head_dim = args.hidden_size // self.num_attention_heads self.scale = self.head_dim**-0.5 - self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * self.head_dim * 2, bias=args.attention_bias) - self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=args.attention_bias) - self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=args.attention_bias) - self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, args.hidden_size, bias=args.attention_bias) + self.q_proj = nn.Linear( + args.hidden_size, + self.num_attention_heads * self.head_dim * 2, + bias=args.attention_bias, + ) + self.k_proj = nn.Linear( + args.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=args.attention_bias, + ) + self.v_proj = nn.Linear( + args.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=args.attention_bias, + ) + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, + args.hidden_size, + bias=args.attention_bias, + ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=args.rms_norm_eps) self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=args.rms_norm_eps) @@ -135,18 +159,24 @@ def __call__( cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape - + q_proj_output = self.q_proj(x) - queries, gate = mx.split(q_proj_output.reshape(B, L, self.num_attention_heads, -1, 2), 2, axis=-1) + queries, gate = mx.split( + q_proj_output.reshape(B, L, self.num_attention_heads, -1, 2), 2, axis=-1 + ) queries = queries.squeeze(-1) gate = gate.squeeze(-1).reshape(B, L, -1) - + keys, values = self.k_proj(x), self.v_proj(x) - + queries = self.q_norm(queries).transpose(0, 2, 1, 3) - keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - + keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( + 0, 2, 1, 3 + ) + if cache is not None: queries = self.rope(queries, offset=cache.offset) keys = self.rope(keys, offset=cache.offset) @@ -154,12 +184,12 @@ def __call__( else: queries = self.rope(queries) keys = self.rope(keys) - + output = scaled_dot_product_attention( queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - + return self.o_proj(output * mx.sigmoid(gate)) @@ -198,7 +228,9 @@ def __init__(self, config: ModelArgs): padding=self.conv_kernel_size - 1, ) - self.in_proj_qkvz = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim * 2, bias=False) + self.in_proj_qkvz = nn.Linear( + self.hidden_size, self.key_dim * 2 + self.value_dim * 2, bias=False + ) self.in_proj_ba = nn.Linear(self.hidden_size, self.num_v_heads * 2, bias=False) self.dt_bias = mx.ones(self.num_v_heads) @@ -209,13 +241,20 @@ def __init__(self, config: ModelArgs): self.norm = Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) - + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): - nq, nk, nv, dv = self.num_k_heads, self.head_k_dim, self.num_v_heads, self.head_v_dim - mixed_qkvz = mixed_qkvz.reshape(*mixed_qkvz.shape[:-1], nq, 2*nk + 2*nv*dv//nq) - mixed_ba = mixed_ba.reshape(*mixed_ba.shape[:-1], nq, 2*nv//nq) - q,k,v,z = mx.split(mixed_qkvz,[nk,2*nk,2*nk+nv//nq*dv],axis=-1) - b,a = mx.split(mixed_ba,[nv//nq],axis=-1) + nq, nk, nv, dv = ( + self.num_k_heads, + self.head_k_dim, + self.num_v_heads, + self.head_v_dim, + ) + mixed_qkvz = mixed_qkvz.reshape( + *mixed_qkvz.shape[:-1], nq, 2 * nk + 2 * nv * dv // nq + ) + mixed_ba = mixed_ba.reshape(*mixed_ba.shape[:-1], nq, 2 * nv // nq) + q, k, v, z = mx.split(mixed_qkvz, [nk, 2 * nk, 2 * nk + nv // nq * dv], axis=-1) + b, a = mx.split(mixed_ba, [nv // nq], axis=-1) return ( q, k, @@ -224,21 +263,39 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): b.reshape(b.shape[0], b.shape[1], nv), a.reshape(a.shape[0], a.shape[1], nv), ) - - def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None): + + def __call__( + self, + inputs: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ): B, S, _ = inputs.shape - q, k, v, z, b, a = self.fix_query_key_value_ordering(self.in_proj_qkvz(inputs), self.in_proj_ba(inputs)) + q, k, v, z, b, a = self.fix_query_key_value_ordering( + self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) + ) # Conv1d on concatenated q/k/v - conv_out = nn.silu(self.conv1d(mx.pad(mx.concatenate( - [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], -1 - ), [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]))[:, :S]) + conv_out = nn.silu( + self.conv1d( + mx.pad( + mx.concatenate( + [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], + -1, + ), + [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)], + ) + )[:, :S] + ) - q, k, v = [t.reshape(B, S, h, d) for t, h, d in zip( - mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), - [self.num_k_heads, self.num_k_heads, self.num_v_heads], - [self.head_k_dim, self.head_k_dim, self.head_v_dim] - )] + q, k, v = [ + t.reshape(B, S, h, d) + for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim], + ) + ] beta, g = mx.sigmoid(b), -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) @@ -247,9 +304,16 @@ def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Opt q, k = mx.repeat(q, f, 2), mx.repeat(k, f, 2) out = recurrent_gated_delta_rule( - q, k, v, g, beta, - mx.zeros((B, self.num_v_heads, self.head_k_dim, self.head_v_dim), dtype=inputs.dtype), - use_qk_l2norm_in_kernel=True + q, + k, + v, + g, + beta, + mx.zeros( + (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), + dtype=inputs.dtype, + ), + use_qk_l2norm_in_kernel=True, ) out = self.norm(out.reshape(-1, out.shape[-1]), z.reshape(-1, z.shape[-1])) return self.out_proj(out.reshape(B, S, -1)) @@ -301,20 +365,20 @@ def __init__(self, args: ModelArgs, layer_idx: int): self.linear_attn = Qwen3NextGatedDeltaNet(args) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention(args) - + self.input_layernorm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = Qwen3NextRMSNorm( args.hidden_size, eps=args.rms_norm_eps ) self.args = args - + if (layer_idx not in args.mlp_only_layers) and ( args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0 ): self.mlp = Qwen3NextSparseMoeBlock(args) else: self.mlp = Qwen3NextMLP(args.hidden_size, args.intermediate_size) - + def __call__( self, x: mx.array, @@ -391,7 +455,7 @@ def make_cache(self): elif l.layer_type == "full_attention": caches.append(KVCache()) return caches - + def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights @@ -421,4 +485,4 @@ def sanitize(self, weights): v = weights[k] if len(v.shape) == 1: weights[k] = v + 1.0 - return weights \ No newline at end of file + return weights From e42be941647c33ab73517b898cb1be4abc2df06f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:26:32 +0200 Subject: [PATCH 29/44] nits --- mlx_lm/models/qwen3_next.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 89e66d83c..f18fdeec2 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -7,16 +7,11 @@ import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .cache import ArraysCache, KVCache +from .cache import MambaCache, KVCache from .rope_utils import initialize_rope from .switch_layers import SwitchGLU -class MambaCache(ArraysCache): - def __init__(self): - super().__init__(size=2) - - @dataclass class ModelArgs(BaseModelArgs): model_type: str From ac55338261469cd9c3e49dcd2a2a311dcd9b8ed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:26:53 +0200 Subject: [PATCH 30/44] format again --- mlx_lm/models/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index f18fdeec2..adb18e664 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -7,7 +7,7 @@ import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .cache import MambaCache, KVCache +from .cache import KVCache, MambaCache from .rope_utils import initialize_rope from .switch_layers import SwitchGLU From fa2e5c4bebdbac13ec829a714d04ae2cdd7bf5d4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 11 Sep 2025 18:51:49 +0200 Subject: [PATCH 31/44] set some defaults --- mlx_lm/models/qwen3_next.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index adb18e664..4b65e8799 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -34,10 +34,10 @@ class ModelArgs(BaseModelArgs): vocab_size: int num_key_value_heads: int rope_theta: float - tie_word_embeddings: bool max_position_embeddings: int - norm_topk_prob: bool - attention_bias: bool + norm_topk_prob: bool = False + tie_word_embeddings: bool = False + attention_bias: bool = False layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None From e1f104e3063f33d830b0f46b80cf9b94088d0a8f Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 11 Sep 2025 18:54:41 +0200 Subject: [PATCH 32/44] alternateing layer defaults --- mlx_lm/models/qwen3_next.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 4b65e8799..b8d9323f1 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -40,6 +40,14 @@ class ModelArgs(BaseModelArgs): attention_bias: bool = False layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None + full_attention_interval: int = 4 + + def __post_init__(self): + if self.layer_types is None: + self.layer_types = [ + "linear_attention" if (i + 1) % self.full_attention_interval else "full_attention" + for i in range(self.num_hidden_layers) + ] @mx.compile From 7d248a14dd4d8749773ebb87754f7ea40e1637fe Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 11 Sep 2025 18:57:47 +0200 Subject: [PATCH 33/44] remove MTP layers --- mlx_lm/models/qwen3_next.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index b8d9323f1..04f644bb2 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -488,4 +488,9 @@ def sanitize(self, weights): v = weights[k] if len(v.shape) == 1: weights[k] = v + 1.0 + weights = { + key: value + for key, value in weights.items() + if "mtp." not in key + } return weights From 06a97ed54410fdaca9f9adc283cf09151d27aaa6 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 11 Sep 2025 18:59:38 +0200 Subject: [PATCH 34/44] add head dim but optional --- mlx_lm/models/qwen3_next.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 04f644bb2..8d2320a35 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -38,6 +38,7 @@ class ModelArgs(BaseModelArgs): norm_topk_prob: bool = False tie_word_embeddings: bool = False attention_bias: bool = False + head_dim: Optional[int] = None layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None full_attention_interval: int = 4 @@ -48,6 +49,8 @@ def __post_init__(self): "linear_attention" if (i + 1) % self.full_attention_interval else "full_attention" for i in range(self.num_hidden_layers) ] + if self.head_dim is None: + self.head_dim = self.hidden_size // self.num_attention_heads @mx.compile @@ -120,7 +123,7 @@ def __init__(self, args: ModelArgs): super().__init__() self.num_key_value_heads = args.num_key_value_heads self.num_attention_heads = args.num_attention_heads - self.head_dim = args.hidden_size // self.num_attention_heads + self.head_dim = args.head_dim self.scale = self.head_dim**-0.5 self.q_proj = nn.Linear( From 6e30a19346f59e720a9ed037631604797b12939f Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 11 Sep 2025 19:12:33 +0200 Subject: [PATCH 35/44] nits + format --- mlx_lm/models/qwen3_next.py | 40 +++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 8d2320a35..693710683 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -46,7 +46,11 @@ class ModelArgs(BaseModelArgs): def __post_init__(self): if self.layer_types is None: self.layer_types = [ - "linear_attention" if (i + 1) % self.full_attention_interval else "full_attention" + ( + "linear_attention" + if (i + 1) % self.full_attention_interval + else "full_attention" + ) for i in range(self.num_hidden_layers) ] if self.head_dim is None: @@ -90,7 +94,8 @@ def recurrent_gated_delta_rule( ) out[:, t] = mx.einsum("bhd,bhdv->bhv", query[:, t], state) - return mx.transpose(out, (0, 2, 1, 3)) + out_transposed = mx.transpose(out, (0, 2, 1, 3)) + return out_transposed, state class Qwen3NextRMSNormGated(nn.Module): @@ -281,7 +286,6 @@ def __call__( self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) ) - # Conv1d on concatenated q/k/v conv_out = nn.silu( self.conv1d( mx.pad( @@ -309,18 +313,28 @@ def __call__( f = self.num_v_heads // self.num_k_heads q, k = mx.repeat(q, f, 2), mx.repeat(k, f, 2) - out = recurrent_gated_delta_rule( + initial_state = None + if cache is not None: + initial_state = cache[0] + out, new_state = recurrent_gated_delta_rule( q, k, v, g, beta, - mx.zeros( - (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), - dtype=inputs.dtype, + ( + initial_state + if initial_state is not None + else mx.zeros( + (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), + dtype=inputs.dtype, + ) ), use_qk_l2norm_in_kernel=True, ) + + if cache is not None: + cache[0] = new_state out = self.norm(out.reshape(-1, out.shape[-1]), z.reshape(-1, z.shape[-1])) return self.out_proj(out.reshape(B, S, -1)) @@ -423,7 +437,11 @@ def __call__( causal_mask = create_attention_mask(hidden_states, cache) for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, mask=causal_mask, cache=cache[i]) + hidden_states = layer( + hidden_states, + mask=causal_mask, + cache=cache[i] if cache is not None else None, + ) return self.norm(hidden_states) @@ -491,9 +509,5 @@ def sanitize(self, weights): v = weights[k] if len(v.shape) == 1: weights[k] = v + 1.0 - weights = { - key: value - for key, value in weights.items() - if "mtp." not in key - } + weights = {key: value for key, value in weights.items() if "mtp." not in key} return weights From 605c4c689839780599db8fdf24df693fe32f6899 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 11 Sep 2025 12:59:16 -0700 Subject: [PATCH 36/44] some nits --- mlx_lm/models/qwen3_next.py | 96 +++++++++++++++---------------------- 1 file changed, 39 insertions(+), 57 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 693710683..0ff260786 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -39,23 +39,9 @@ class ModelArgs(BaseModelArgs): tie_word_embeddings: bool = False attention_bias: bool = False head_dim: Optional[int] = None - layer_types: Optional[List[str]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None full_attention_interval: int = 4 - def __post_init__(self): - if self.layer_types is None: - self.layer_types = [ - ( - "linear_attention" - if (i + 1) % self.full_attention_interval - else "full_attention" - ) - for i in range(self.num_hidden_layers) - ] - if self.head_dim is None: - self.head_dim = self.hidden_size // self.num_attention_heads - @mx.compile def recurrent_gated_delta_rule( @@ -369,29 +355,25 @@ def __call__( y = self.switch_mlp(x, inds) y = (y * scores[..., None]).sum(axis=-2) - shared_expert_output = self.shared_expert(x) - shared_expert_output = ( - mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - ) + shared_y = self.shared_expert(x) + shared_y = mx.sigmoid(self.shared_expert_gate(x)) * shared_y - return y + shared_expert_output + return y + shared_y class Qwen3NextDecoderLayer(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() - self.layer_type = args.layer_types[layer_idx] - if self.layer_type == "linear_attention": + self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0 + if self.is_linear: self.linear_attn = Qwen3NextGatedDeltaNet(args) - elif self.layer_type == "full_attention": + else: self.self_attn = Qwen3NextAttention(args) self.input_layernorm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = Qwen3NextRMSNorm( args.hidden_size, eps=args.rms_norm_eps ) - self.args = args - if (layer_idx not in args.mlp_only_layers) and ( args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0 ): @@ -405,9 +387,9 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Any] = None, ) -> mx.array: - if self.layer_type == "linear_attention": + if self.is_linear: r = self.linear_attn(self.input_layernorm(x), mask, cache) - elif self.layer_type == "full_attention": + else: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r out = h + self.mlp(self.post_attention_layernorm(h)) @@ -423,6 +405,7 @@ def __init__(self, args: ModelArgs): for i in range(args.num_hidden_layers) ] self.norm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.fa_idx = args.full_attention_interval - 1 def __call__( self, @@ -434,14 +417,10 @@ def __call__( if cache is None: cache = [None] * len(self.layers) - causal_mask = create_attention_mask(hidden_states, cache) + mask = create_attention_mask(hidden_states, cache[self.fa_idx]) - for i, layer in enumerate(self.layers): - hidden_states = layer( - hidden_states, - mask=causal_mask, - cache=cache[i] if cache is not None else None, - ) + for layer, c in zip(self.layers, cache): + hidden_states = layer(hidden_states, mask=mask, cache=c) return self.norm(hidden_states) @@ -472,31 +451,25 @@ def layers(self): return self.model.layers def make_cache(self): - caches = [] - for l in self.layers: - if l.layer_type == "linear_attention": - caches.append(MambaCache()) - elif l.layer_type == "full_attention": - caches.append(KVCache()) - return caches + return [MambaCache() if l.is_linear else KVCache() for l in self.layers] def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - if f"{prefix}.mlp.experts.0.{n}.weight" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(to_join) - for k, v in weights.items(): - if "conv1d.weight" in k and v.shape[-1] != 1: - weights[k] = v.moveaxis(2, 1) + weights = {key: value for key, value in weights.items() if "mtp." not in key} + if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) + + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}.mlp" + for n in ["up_proj", "down_proj", "gate_proj"]: + to_join = [ + weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + norm_keys = ( ".input_layernorm.weight", ".post_attention_layernorm.weight", @@ -504,10 +477,19 @@ def sanitize(self, weights): ".q_norm.weight", ".k_norm.weight", ) - for k in list(weights.keys()): - if any(sfx in k for sfx in norm_keys): - v = weights[k] - if len(v.shape) == 1: + for k, v in weights.items(): + if "conv1d.weight" in k and v.shape[-1] != 1: + weights[k] = v.moveaxis(2, 1) + if any(k.endswith(sfx) for sfx in norm_keys): + if v.ndim == 1: weights[k] = v + 1.0 - weights = {key: value for key, value in weights.items() if "mtp." not in key} return weights + + @property + def quant_predicate(self): + def predicate(path, _): + if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): + return {"group_size": 64, "bits": 8} + return True + + return predicate From 39f207dd1d82548bb724e17cce5913f94a33d33a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 11 Sep 2025 14:58:23 -0700 Subject: [PATCH 37/44] some fixes --- mlx_lm/models/qwen3_next.py | 101 +++++++++++++++++------------------- 1 file changed, 47 insertions(+), 54 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 0ff260786..e6c656c6c 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -34,6 +34,7 @@ class ModelArgs(BaseModelArgs): vocab_size: int num_key_value_heads: int rope_theta: float + partial_rotary_factor: float max_position_embeddings: int norm_topk_prob: bool = False tie_word_embeddings: bool = False @@ -50,38 +51,35 @@ def recurrent_gated_delta_rule( value: mx.array, g: mx.array, beta: mx.array, - initial_state: Optional[mx.array] = None, + initial_state: mx.array, use_qk_l2norm_in_kernel: bool = False, ): + B, S, H, Dk = key.shape + inv_scale = Dk**-0.5 if use_qk_l2norm_in_kernel: - query /= mx.maximum(mx.linalg.norm(query, axis=-1, keepdims=True), 1e-12) - key /= mx.maximum(mx.linalg.norm(key, axis=-1, keepdims=True), 1e-12) + query = inv_scale * mx.fast.rms_norm(query, None, 1e-6) + key = inv_scale * mx.fast.rms_norm(key, None, 1e-6) if beta.ndim == 4: beta = beta.squeeze(-1) if g.ndim == 4: g = g.squeeze(-1) - B, S, H, Dk = key.shape Dv = value.shape[-1] - query *= 1.0 / mx.sqrt(mx.array(query.shape[-1])) + query = inv_scale * query v_beta, k_beta = value * beta[..., None], key * beta[..., None] - state = ( - mx.zeros((B, H, Dk, Dv), dtype=value.dtype) - if initial_state is None - else initial_state.reshape(B, H, Dk, Dv) - ) out = mx.zeros((B, S, H, Dv), dtype=value.dtype) + state = initial_state + for t in range(S): state = state * mx.exp(g[:, t])[..., None, None] + mx.matmul( k_beta[:, t][..., None], v_beta[:, t][..., None, :] ) out[:, t] = mx.einsum("bhd,bhdv->bhv", query[:, t], state) - out_transposed = mx.transpose(out, (0, 2, 1, 3)) - return out_transposed, state + return out, state class Qwen3NextRMSNormGated(nn.Module): @@ -93,22 +91,12 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): def __call__( self, hidden_states: mx.array, gate: mx.array | None = None ) -> mx.array: - x = mx.fast.rms_norm(hidden_states, None, self.eps) * (1.0 + self.weight) + x = mx.fast.rms_norm(hidden_states, self.weight, self.eps) if gate is not None: x = x * nn.silu(gate) return x -class Qwen3NextRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = mx.zeros(dim) - - def __call__(self, x: mx.array) -> mx.array: - return mx.fast.rms_norm(x, None, self.eps) * (1.0 + self.weight) - - class Qwen3NextAttention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -138,11 +126,11 @@ def __init__(self, args: ModelArgs): bias=args.attention_bias, ) - self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=args.rms_norm_eps) - self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) self.rope = initialize_rope( - self.head_dim, + int(self.head_dim * args.partial_rotary_factor), base=args.rope_theta, traditional=False, scaling_config=args.rope_scaling, @@ -222,7 +210,7 @@ def __init__(self, config: ModelArgs): bias=False, kernel_size=self.conv_kernel_size, groups=self.conv_dim, - padding=self.conv_kernel_size - 1, + padding=0, ) self.in_proj_qkvz = nn.Linear( @@ -272,17 +260,21 @@ def __call__( self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) ) - conv_out = nn.silu( - self.conv1d( - mx.pad( - mx.concatenate( - [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], - -1, - ), - [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)], - ) - )[:, :S] + if cache is not None and cache[0] is not None: + conv_state = cache[0] + else: + conv_state = mx.zeros( + (B, self.conv_kernel_size - 1, self.conv_dim), + dtype=inputs.dtype, + ) + + mixed_qkv = mx.concatenate( + [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], axis=-1 ) + conv_input = mx.concatenate([conv_state, mixed_qkv], axis=1) + if cache is not None: + cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :] + conv_out = nn.silu(self.conv1d(conv_input)) q, k, v = [ t.reshape(B, S, h, d) @@ -299,29 +291,27 @@ def __call__( f = self.num_v_heads // self.num_k_heads q, k = mx.repeat(q, f, 2), mx.repeat(k, f, 2) - initial_state = None - if cache is not None: - initial_state = cache[0] + if cache is not None and cache[1] is not None: + initial_state = cache[1] + else: + initial_state = mx.zeros( + (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), + dtype=inputs.dtype, + ) + out, new_state = recurrent_gated_delta_rule( q, k, v, g, beta, - ( - initial_state - if initial_state is not None - else mx.zeros( - (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), - dtype=inputs.dtype, - ) - ), + initial_state, use_qk_l2norm_in_kernel=True, ) if cache is not None: - cache[0] = new_state - out = self.norm(out.reshape(-1, out.shape[-1]), z.reshape(-1, z.shape[-1])) + cache[1] = new_state + out = self.norm(out, z) return self.out_proj(out.reshape(B, S, -1)) @@ -332,6 +322,7 @@ def __init__(self, args: ModelArgs): intermediate_size = args.moe_intermediate_size shared_expert_intermediate_size = args.shared_expert_intermediate_size + self.norm_topk_prob = args.norm_topk_prob self.num_experts = num_experts = args.num_experts self.top_k = args.num_experts_per_tok @@ -351,6 +342,8 @@ def __call__( k = self.top_k inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) scores = mx.take_along_axis(gates, inds, axis=-1) + if self.norm_topk_prob: + scores = scores / scores.sum(axis=-1, keepdims=True) y = self.switch_mlp(x, inds) y = (y * scores[..., None]).sum(axis=-2) @@ -370,8 +363,8 @@ def __init__(self, args: ModelArgs, layer_idx: int): else: self.self_attn = Qwen3NextAttention(args) - self.input_layernorm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = Qwen3NextRMSNorm( + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( args.hidden_size, eps=args.rms_norm_eps ) if (layer_idx not in args.mlp_only_layers) and ( @@ -404,7 +397,7 @@ def __init__(self, args: ModelArgs): Qwen3NextDecoderLayer(args=args, layer_idx=i) for i in range(args.num_hidden_layers) ] - self.norm = Qwen3NextRMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.fa_idx = args.full_attention_interval - 1 def __call__( @@ -473,7 +466,7 @@ def sanitize(self, weights): norm_keys = ( ".input_layernorm.weight", ".post_attention_layernorm.weight", - ".norm.weight", + "model.norm.weight", ".q_norm.weight", ".k_norm.weight", ) From bcf76acf9eaed47f04ad478a7001fa9f2db1aacf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 11 Sep 2025 19:48:46 -0700 Subject: [PATCH 38/44] fixes --- mlx_lm/models/qwen3_next.py | 43 +++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index e6c656c6c..c095aae1f 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -44,7 +44,6 @@ class ModelArgs(BaseModelArgs): full_attention_interval: int = 4 -@mx.compile def recurrent_gated_delta_rule( query: mx.array, key: mx.array, @@ -60,25 +59,33 @@ def recurrent_gated_delta_rule( query = inv_scale * mx.fast.rms_norm(query, None, 1e-6) key = inv_scale * mx.fast.rms_norm(key, None, 1e-6) - if beta.ndim == 4: - beta = beta.squeeze(-1) - if g.ndim == 4: - g = g.squeeze(-1) + input_type = query.dtype Dv = value.shape[-1] query = inv_scale * query - v_beta, k_beta = value * beta[..., None], key * beta[..., None] - out = mx.zeros((B, S, H, Dv), dtype=value.dtype) + out = mx.zeros((B, H, S, Dv), dtype=value.dtype) - state = initial_state - - for t in range(S): - state = state * mx.exp(g[:, t])[..., None, None] + mx.matmul( - k_beta[:, t][..., None], v_beta[:, t][..., None, :] - ) - out[:, t] = mx.einsum("bhd,bhdv->bhv", query[:, t], state) + query, key, value, beta, g = [ + x.swapaxes(1, 2).astype(mx.float32) for x in (query, key, value, beta, g) + ] + state = initial_state + g = mx.exp(g) + + for i in range(S): + q_t = query[:, :, i][..., None] + k_t = key[:, :, i][..., None] + v_t = value[:, :, i] + g_t = g[:, :, i][..., None, None] + beta_t = beta[:, :, i][..., None] + + state = state * g_t + kv_mem = (state * k_t).sum(axis=-2) + delta = (v_t - kv_mem) * beta_t + state = state + k_t * delta[..., None, :] + out[:, :, i] = (state * q_t).sum(axis=-2) + out = out.swapaxes(1, 2).astype(input_type) return out, state @@ -147,10 +154,9 @@ def __call__( q_proj_output = self.q_proj(x) queries, gate = mx.split( - q_proj_output.reshape(B, L, self.num_attention_heads, -1, 2), 2, axis=-1 + q_proj_output.reshape(B, L, self.num_attention_heads, -1), 2, axis=-1 ) - queries = queries.squeeze(-1) - gate = gate.squeeze(-1).reshape(B, L, -1) + gate = gate.reshape(B, L, -1) keys, values = self.k_proj(x), self.v_proj(x) @@ -285,7 +291,8 @@ def __call__( ) ] - beta, g = mx.sigmoid(b), -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) + beta = mx.sigmoid(b) + g = -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) if self.num_v_heads > self.num_k_heads: f = self.num_v_heads // self.num_k_heads From ef346c32da550bff1a72fafb7ebdb9d04cfc4a3e Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 12 Sep 2025 10:39:22 +0200 Subject: [PATCH 39/44] move f to innit --- mlx_lm/models/qwen3_next.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index c095aae1f..5c3127675 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -205,6 +205,11 @@ def __init__(self, config: ModelArgs): self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads + if self.num_v_heads % self.num_k_heads != 0: + raise ValueError( + f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})" + ) + self.repeat_factor = self.num_v_heads // self.num_k_heads self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_norm_epsilon = config.rms_norm_eps @@ -294,9 +299,8 @@ def __call__( beta = mx.sigmoid(b) g = -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) - if self.num_v_heads > self.num_k_heads: - f = self.num_v_heads // self.num_k_heads - q, k = mx.repeat(q, f, 2), mx.repeat(k, f, 2) + if self.repeat_factor > 1: + q, k = mx.repeat(q, self.repeat_factor, 2), mx.repeat(k, self.repeat_factor, 2) if cache is not None and cache[1] is not None: initial_state = cache[1] From 7b792c9877915979e1e8b2e626b3b1fe86442a80 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 12 Sep 2025 11:00:10 +0200 Subject: [PATCH 40/44] optimized recurrent_gated_delta_rule --- mlx_lm/models/qwen3_next.py | 42 ++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 5c3127675..675c39574 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -52,39 +52,37 @@ def recurrent_gated_delta_rule( beta: mx.array, initial_state: mx.array, use_qk_l2norm_in_kernel: bool = False, -): +) -> Tuple[mx.array, mx.array]: B, S, H, Dk = key.shape inv_scale = Dk**-0.5 if use_qk_l2norm_in_kernel: query = inv_scale * mx.fast.rms_norm(query, None, 1e-6) key = inv_scale * mx.fast.rms_norm(key, None, 1e-6) - input_type = query.dtype - Dv = value.shape[-1] query = inv_scale * query - out = mx.zeros((B, H, S, Dv), dtype=value.dtype) - query, key, value, beta, g = [ x.swapaxes(1, 2).astype(mx.float32) for x in (query, key, value, beta, g) ] - state = initial_state g = mx.exp(g) - - for i in range(S): - q_t = query[:, :, i][..., None] - k_t = key[:, :, i][..., None] - v_t = value[:, :, i] - g_t = g[:, :, i][..., None, None] - beta_t = beta[:, :, i][..., None] - + q_splits = mx.split(query, S, axis=2) + k_splits = mx.split(key, S, axis=2) + v_splits = mx.split(value, S, axis=2) + g_splits = mx.split(g, S, axis=2) + beta_splits = mx.split(beta, S, axis=2) + for i, (q_t, k_t, v_t, g_t, beta_t) in enumerate(zip(q_splits, k_splits, v_splits, g_splits, beta_splits)): + q_t = q_t.squeeze(2)[..., None] + k_t = k_t.squeeze(2)[..., None] + v_t = v_t.squeeze(2) + g_t = g_t.squeeze(2)[..., None, None] + beta_t = beta_t.squeeze(2)[..., None] state = state * g_t - kv_mem = (state * k_t).sum(axis=-2) + kv_mem = mx.einsum('bhkv,bhk->bhv', state, k_t.squeeze(-1)) delta = (v_t - kv_mem) * beta_t - state = state + k_t * delta[..., None, :] - out[:, :, i] = (state * q_t).sum(axis=-2) + state = state + mx.einsum('bhk,bhv->bhkv', k_t.squeeze(-1), delta) + out[:, :, i] = mx.einsum('bhkv,bhk->bhv', state, q_t.squeeze(-1)) out = out.swapaxes(1, 2).astype(input_type) return out, state @@ -238,7 +236,7 @@ def __init__(self, config: ModelArgs): self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) - def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + def fix_query_key_value_ordering(self, mixed_qkvz: mx.array, mixed_ba: mx.array) -> mx.array: nq, nk, nv, dv = ( self.num_k_heads, self.head_k_dim, @@ -265,7 +263,7 @@ def __call__( inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, - ): + ) -> mx.array: B, S, _ = inputs.shape q, k, v, z, b, a = self.fix_query_key_value_ordering( self.in_proj_qkvz(inputs), self.in_proj_ba(inputs) @@ -346,7 +344,7 @@ def __init__(self, args: ModelArgs): def __call__( self, x: mx.array, - ): + ) -> mx.array: gates = self.gate(x) gates = mx.softmax(gates, axis=-1, precise=True) @@ -415,7 +413,7 @@ def __call__( self, inputs: mx.array, cache: Optional[Any] = None, - ): + ) -> mx.array: hidden_states = self.embed_tokens(inputs) if cache is None: @@ -442,7 +440,7 @@ def __call__( self, inputs: mx.array, cache: Optional[Any] = None, - ): + ) -> mx.array: out = self.model(inputs, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) From a0685b1ec6cc66eed30719dc800014be428c3f34 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 12 Sep 2025 11:48:47 +0200 Subject: [PATCH 41/44] optmize and shorten recurrent_gated_delta_rule a lot + moving g = mx.exp(g) up to fix gibberish output --- mlx_lm/models/qwen3_next.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 675c39574..2202db394 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -59,32 +59,20 @@ def recurrent_gated_delta_rule( query = inv_scale * mx.fast.rms_norm(query, None, 1e-6) key = inv_scale * mx.fast.rms_norm(key, None, 1e-6) input_type = query.dtype - Dv = value.shape[-1] query = inv_scale * query - out = mx.zeros((B, H, S, Dv), dtype=value.dtype) + out = mx.zeros((B, H, S, value.shape[-1]), dtype=value.dtype) + g = mx.exp(g) query, key, value, beta, g = [ x.swapaxes(1, 2).astype(mx.float32) for x in (query, key, value, beta, g) ] state = initial_state - g = mx.exp(g) - q_splits = mx.split(query, S, axis=2) - k_splits = mx.split(key, S, axis=2) - v_splits = mx.split(value, S, axis=2) - g_splits = mx.split(g, S, axis=2) - beta_splits = mx.split(beta, S, axis=2) - for i, (q_t, k_t, v_t, g_t, beta_t) in enumerate(zip(q_splits, k_splits, v_splits, g_splits, beta_splits)): - q_t = q_t.squeeze(2)[..., None] - k_t = k_t.squeeze(2)[..., None] - v_t = v_t.squeeze(2) - g_t = g_t.squeeze(2)[..., None, None] - beta_t = beta_t.squeeze(2)[..., None] - state = state * g_t - kv_mem = mx.einsum('bhkv,bhk->bhv', state, k_t.squeeze(-1)) - delta = (v_t - kv_mem) * beta_t - state = state + mx.einsum('bhk,bhv->bhkv', k_t.squeeze(-1), delta) - out[:, :, i] = mx.einsum('bhkv,bhk->bhv', state, q_t.squeeze(-1)) - out = out.swapaxes(1, 2).astype(input_type) - return out, state + for i in range(S): + state *= g[:, :, i, None, None] + kv_mem = mx.einsum('bhkv,bhk->bhv', state, key[:, :, i]) + delta = (value[:, :, i] - kv_mem) * beta[:, :, i, None] + state += mx.einsum('bhk,bhv->bhkv', key[:, :, i], delta) + out[:, :, i] = mx.einsum('bhkv,bhk->bhv', state, query[:, :, i]) + return out.swapaxes(1, 2).astype(input_type), state class Qwen3NextRMSNormGated(nn.Module): From ca24475f8b4ce5ac8b889598f869491d345030e3 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 12 Sep 2025 14:39:42 +0200 Subject: [PATCH 42/44] make train better --- mlx_lm/tuner/utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mlx_lm/tuner/utils.py b/mlx_lm/tuner/utils.py index 9c0b40bd4..82988998a 100644 --- a/mlx_lm/tuner/utils.py +++ b/mlx_lm/tuner/utils.py @@ -135,7 +135,7 @@ def to_lora(layer): keys = {"self_attn.q_proj", "self_attn.v_proj"} if model.model_type in ["mixtral", "phimoe"]: keys.add("block_sparse_moe.gate") - if model.model_type == "qwen2_moe": + if model.model_type in ["qwen2_moe", "qwen3_next"]: keys.add("mlp.gate") keys.add("mlp.shared_expert_gate") if model.model_type in ["olmoe", "qwen3_moe", "dots1", "Klear"]: @@ -148,6 +148,15 @@ def to_lora(layer): keys.add("feed_forward.gate_proj") keys.add("feed_forward.up_proj") keys.add("feed_forward.down_proj") + elif model.model_type == "qwen3_next": + keys.add("linear_attn.in_proj_qkvz") + keys.add("linear_attn.out_proj") + keys.add("linear_attn.in_proj_ba") + keys.add("linear_attn.dt_bias") + keys.add("self_attn.q_proj") + keys.add("self_attn.k_proj") + keys.add("self_attn.v_proj") + keys.add("self_attn.o_proj") elif model.model_type == "gpt_bigcode": keys = {"attn.c_attn"} elif model.model_type == "gpt2": @@ -190,10 +199,6 @@ def to_lora(layer): keys = {"attn.attention.q_proj", "attn.attention.v_proj"} elif model.model_type == "bailing_moe": keys = {"attention.query_key_value", "attention.dense"} - elif model.model_type == "qwen3_next": - keys.add("self_attn.in_proj_qkvz") - keys.add("self_attn.in_proj_ba") - keys.add("self_attn.out_proj") elif model.model_type == "nemotron_h": keys.add("mixer.in_proj") keys.add("mixer.out_proj") From 8a9809a7f9bc1d3c51a7a1b3e77680669fa1a52d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 12 Sep 2025 11:57:22 -0700 Subject: [PATCH 43/44] nits --- mlx_lm/models/qwen3_moe.py | 2 +- mlx_lm/models/qwen3_next.py | 86 +++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py index eea1f42c6..d94758a7a 100644 --- a/mlx_lm/models/qwen3_moe.py +++ b/mlx_lm/models/qwen3_moe.py @@ -127,7 +127,7 @@ def __call__( gates = mx.softmax(gates, axis=-1, precise=True) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:] scores = mx.take_along_axis(gates, inds, axis=-1) if self.norm_topk_prob: scores /= mx.sum(scores, axis=-1, keepdims=True) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index 2202db394..aba99d9be 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -48,31 +48,43 @@ def recurrent_gated_delta_rule( query: mx.array, key: mx.array, value: mx.array, - g: mx.array, - beta: mx.array, - initial_state: mx.array, - use_qk_l2norm_in_kernel: bool = False, + A_log: mx.array, + dt_bias: mx.array, + a: mx.array, + b: mx.array, + initial_state: Optional[mx.array], + use_qk_l2norm_in_kernel: bool, ) -> Tuple[mx.array, mx.array]: - B, S, H, Dk = key.shape + + input_type = query.dtype + B, S, Hk, Dk = key.shape + Hv, Dv = value.shape[-2:] + beta = mx.sigmoid(b) + g = -mx.exp(A_log) * nn.softplus(a + dt_bias) + initial_state = mx.zeros((B, Hv, Dk, Dv), dtype=input_type) + inv_scale = Dk**-0.5 if use_qk_l2norm_in_kernel: query = inv_scale * mx.fast.rms_norm(query, None, 1e-6) key = inv_scale * mx.fast.rms_norm(key, None, 1e-6) - input_type = query.dtype query = inv_scale * query - out = mx.zeros((B, H, S, value.shape[-1]), dtype=value.dtype) + if (repeat_factor := (Hv // Hk)) > 1: + query = mx.repeat(query, repeat_factor, 2) + key = mx.repeat(key, repeat_factor, 2) + + out = mx.zeros((B, S, Hv, Dv), dtype=value.dtype) g = mx.exp(g) query, key, value, beta, g = [ - x.swapaxes(1, 2).astype(mx.float32) for x in (query, key, value, beta, g) + x.astype(mx.float32) for x in (query, key, value, beta, g) ] state = initial_state for i in range(S): - state *= g[:, :, i, None, None] - kv_mem = mx.einsum('bhkv,bhk->bhv', state, key[:, :, i]) - delta = (value[:, :, i] - kv_mem) * beta[:, :, i, None] - state += mx.einsum('bhk,bhv->bhkv', key[:, :, i], delta) - out[:, :, i] = mx.einsum('bhkv,bhk->bhv', state, query[:, :, i]) - return out.swapaxes(1, 2).astype(input_type), state + state = state * g[:, i, :, None, None] + kv_mem = mx.einsum("bhkv,bhk->bhv", state, key[:, i, :]) + delta = (value[:, i, :] - kv_mem) * beta[:, i, :, None] + state += mx.einsum("bhk,bhv->bhkv", key[:, i, :], delta) + out[:, i, :] = mx.einsum("bhkv,bhk->bhv", state, query[:, i, :]) + return out.astype(input_type), state class Qwen3NextRMSNormGated(nn.Module): @@ -195,7 +207,6 @@ def __init__(self, config: ModelArgs): raise ValueError( f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})" ) - self.repeat_factor = self.num_v_heads // self.num_k_heads self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_norm_epsilon = config.rms_norm_eps @@ -224,26 +235,26 @@ def __init__(self, config: ModelArgs): self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) - def fix_query_key_value_ordering(self, mixed_qkvz: mx.array, mixed_ba: mx.array) -> mx.array: - nq, nk, nv, dv = ( + def fix_query_key_value_ordering( + self, mixed_qkvz: mx.array, mixed_ba: mx.array + ) -> mx.array: + nk, dn, nv, dv = ( self.num_k_heads, self.head_k_dim, self.num_v_heads, self.head_v_dim, ) - mixed_qkvz = mixed_qkvz.reshape( - *mixed_qkvz.shape[:-1], nq, 2 * nk + 2 * nv * dv // nq - ) - mixed_ba = mixed_ba.reshape(*mixed_ba.shape[:-1], nq, 2 * nv // nq) - q, k, v, z = mx.split(mixed_qkvz, [nk, 2 * nk, 2 * nk + nv // nq * dv], axis=-1) - b, a = mx.split(mixed_ba, [nv // nq], axis=-1) + mixed_qkvz = mixed_qkvz.reshape(*mixed_qkvz.shape[:-1], nk, -1) + mixed_ba = mixed_ba.reshape(*mixed_ba.shape[:-1], nk, -1) + q, k, v, z = mx.split(mixed_qkvz, [dn, 2 * dn, 2 * dn + nv // nk * dv], axis=-1) + b, a = mx.split(mixed_ba, [nv // nk], axis=-1) return ( q, k, - v.reshape(v.shape[0], v.shape[1], -1, dv), - z.reshape(z.shape[0], z.shape[1], -1, dv), - b.reshape(b.shape[0], b.shape[1], nv), - a.reshape(a.shape[0], a.shape[1], nv), + v.reshape(*v.shape[:2], -1, dv), + z.reshape(*z.shape[:2], -1, dv), + b.reshape(*b.shape[:2], nv), + a.reshape(*a.shape[:2], nv), ) def __call__( @@ -282,26 +293,19 @@ def __call__( ) ] - beta = mx.sigmoid(b) - g = -mx.exp(self.A_log) * nn.softplus(a + self.dt_bias) - - if self.repeat_factor > 1: - q, k = mx.repeat(q, self.repeat_factor, 2), mx.repeat(k, self.repeat_factor, 2) - - if cache is not None and cache[1] is not None: + if cache is not None: initial_state = cache[1] else: - initial_state = mx.zeros( - (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), - dtype=inputs.dtype, - ) + initial_state = None out, new_state = recurrent_gated_delta_rule( q, k, v, - g, - beta, + self.A_log, + self.dt_bias, + a, + b, initial_state, use_qk_l2norm_in_kernel=True, ) @@ -337,7 +341,7 @@ def __call__( gates = mx.softmax(gates, axis=-1, precise=True) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:] scores = mx.take_along_axis(gates, inds, axis=-1) if self.norm_topk_prob: scores = scores / scores.sum(axis=-1, keepdims=True) From 16ca09aa94145331df8f24a2a38125764681778c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 12 Sep 2025 17:04:04 -0700 Subject: [PATCH 44/44] nits + fix --- mlx_lm/models/qwen3_next.py | 73 ++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/mlx_lm/models/qwen3_next.py b/mlx_lm/models/qwen3_next.py index aba99d9be..19ea7c2d2 100644 --- a/mlx_lm/models/qwen3_next.py +++ b/mlx_lm/models/qwen3_next.py @@ -1,6 +1,7 @@ # Copyright © 2025 Apple Inc. from dataclasses import dataclass +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx @@ -44,47 +45,49 @@ class ModelArgs(BaseModelArgs): full_attention_interval: int = 4 +@partial(mx.compile, shapeless=True) +def compute_g(A_log, a, dt_bias): + return mx.exp(-mx.exp(A_log.astype(mx.float32)) * nn.softplus(a + dt_bias)).astype( + A_log.dtype + ) + + def recurrent_gated_delta_rule( query: mx.array, key: mx.array, value: mx.array, - A_log: mx.array, - dt_bias: mx.array, a: mx.array, b: mx.array, - initial_state: Optional[mx.array], - use_qk_l2norm_in_kernel: bool, + A_log: mx.array, + dt_bias: mx.array, + state: mx.array, + use_qk_l2norm_in_kernel: bool = False, ) -> Tuple[mx.array, mx.array]: - - input_type = query.dtype B, S, Hk, Dk = key.shape - Hv, Dv = value.shape[-2:] - beta = mx.sigmoid(b) - g = -mx.exp(A_log) * nn.softplus(a + dt_bias) - initial_state = mx.zeros((B, Hv, Dk, Dv), dtype=input_type) - + Hv, Dv = value.shape[2:] inv_scale = Dk**-0.5 if use_qk_l2norm_in_kernel: - query = inv_scale * mx.fast.rms_norm(query, None, 1e-6) + query = (inv_scale**2) * mx.fast.rms_norm(query, None, 1e-6) key = inv_scale * mx.fast.rms_norm(key, None, 1e-6) - query = inv_scale * query - if (repeat_factor := (Hv // Hk)) > 1: + else: + query = inv_scale * query + + input_type = query.dtype + if (repeat_factor := Hv // Hk) > 1: query = mx.repeat(query, repeat_factor, 2) key = mx.repeat(key, repeat_factor, 2) - out = mx.zeros((B, S, Hv, Dv), dtype=value.dtype) - g = mx.exp(g) - query, key, value, beta, g = [ - x.astype(mx.float32) for x in (query, key, value, beta, g) - ] - state = initial_state + beta = mx.sigmoid(b) + g = compute_g(A_log, a, dt_bias) + outs = [] for i in range(S): - state = state * g[:, i, :, None, None] - kv_mem = mx.einsum("bhkv,bhk->bhv", state, key[:, i, :]) - delta = (value[:, i, :] - kv_mem) * beta[:, i, :, None] - state += mx.einsum("bhk,bhv->bhkv", key[:, i, :], delta) - out[:, i, :] = mx.einsum("bhkv,bhk->bhv", state, query[:, i, :]) - return out.astype(input_type), state + kv_mem = (state * key[:, i, :, :, None]).sum(axis=-2) + delta = (value[:, i] - kv_mem) * beta[:, i, :, None] + state += key[:, i, :, :, None] * delta[..., None, :] + out = (state * query[:, i, :, :, None]).sum(axis=-2) + outs.append(out) + out = mx.stack(outs, axis=1).astype(input_type) + return out, state class Qwen3NextRMSNormGated(nn.Module): @@ -293,25 +296,29 @@ def __call__( ) ] - if cache is not None: - initial_state = cache[1] + if cache is not None and cache[1] is not None: + state = cache[1] else: - initial_state = None + state = mx.zeros( + (B, self.num_v_heads, self.head_k_dim, self.head_v_dim), + dtype=inputs.dtype, + ) out, new_state = recurrent_gated_delta_rule( q, k, v, - self.A_log, - self.dt_bias, a, b, - initial_state, + self.A_log, + self.dt_bias, + state, use_qk_l2norm_in_kernel=True, ) if cache is not None: - cache[1] = new_state + cache[1] = state + out = self.norm(out, z) return self.out_proj(out.reshape(B, S, -1))