383 changes: 383 additions & 0 deletions mlx_lm/models/qwen3_5.py
@@ -0,0 +1,383 @@
# Copyright © 2026 Apple Inc.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

from .base import (
    BaseModelArgs,
    create_attention_mask,
    create_ssm_mask,
)
from .cache import ArraysCache, KVCache
from .gated_delta import gated_delta_update
from .qwen3_next import Qwen3NextAttention as Attention
from .qwen3_next import Qwen3NextMLP as MLP
from .qwen3_next import Qwen3NextRMSNormGated as RMSNormGated
from .qwen3_next import Qwen3NextSparseMoeBlock as SparseMoeBlock
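
# Qwen3.5 reuses the Qwen3-Next attention, MLP, gated RMSNorm, and
# sparse-MoE blocks (imported above); this file adds the gated-delta
# linear-attention layer plus the config and weight-loading plumbing.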


@dataclass
class TextModelArgs(BaseModelArgs):
    model_type: str = ""
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-6
    vocab_size: int = 151936
    num_key_value_heads: int = 8
    max_position_embeddings: int = 131072
    linear_num_value_heads: int = 64
    linear_num_key_heads: int = 16
    linear_key_head_dim: int = 192
    linear_value_head_dim: int = 128
    linear_conv_kernel_dim: int = 4
    tie_word_embeddings: bool = False
    attention_bias: bool = False
    head_dim: Optional[int] = None
    full_attention_interval: int = 4

    # MoE fields (optional, for Qwen3_5MoeForConditionalGeneration)
    num_experts: int = 0
    num_experts_per_tok: int = 0
    decoder_sparse_step: int = 1
    shared_expert_intermediate_size: int = 0
    moe_intermediate_size: int = 0
    norm_topk_prob: bool = True

    # Rope parameters
    rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field(
        default_factory=lambda: {
            "type": "default",
            "mrope_section": [11, 11, 10],
            "rope_theta": 100000,
            "partial_rotary_factor": 0.25,
        }
    )

    # Derived from rope_parameters (set in __post_init__)
    partial_rotary_factor: float = 0.25
    rope_theta: float = 100000.0
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_attention_heads

        if self.rope_parameters:
            if (
                "type" not in self.rope_parameters
                and "rope_type" in self.rope_parameters
            ):
                self.rope_parameters["type"] = self.rope_parameters.pop("rope_type")

            self.partial_rotary_factor = self.rope_parameters.get(
                "partial_rotary_factor", 0.25
            )
            self.rope_theta = self.rope_parameters.get("rope_theta", 100000.0)
            self.rope_scaling = self.rope_parameters


class GatedDeltaNet(nn.Module):
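    # Linear-attention block: a short depthwise causal conv over the
    # projected q/k/v channels followed by a gated delta-rule recurrence
    # (gated_delta_update), used in place of softmax attention.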
    def __init__(self, config: TextModelArgs):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        if self.num_v_heads % self.num_k_heads != 0:
            raise ValueError(
                f"num_v_heads ({self.num_v_heads}) must be divisible by num_k_heads ({self.num_k_heads})"
            )

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_norm_epsilon = config.rms_norm_eps

        self.conv_dim = self.key_dim * 2 + self.value_dim
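        # Depthwise causal conv: each concatenated q/k/v channel is mixed
        # over the last `conv_kernel_size` timesteps independently
        # (groups == channels); causality comes from explicit left-padding
        # with cached state in __call__.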
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=0,
        )

        self.in_proj_qkv = nn.Linear(
            self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False
        )
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

        # Per-head decay parameters for the gated delta rule; the random
        # init is a placeholder that checkpoint weights overwrite at load.
        self.dt_bias = mx.ones(self.num_v_heads)

        A = mx.random.uniform(low=0, high=16, shape=(self.num_v_heads,))
        self.A_log = mx.log(A)

        self.norm = RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)

        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, S, _ = inputs.shape

        qkv = self.in_proj_qkv(inputs)
        z = self.in_proj_z(inputs).reshape(B, S, self.num_v_heads, self.head_v_dim)
        b = self.in_proj_b(inputs)
        a = self.in_proj_a(inputs)

        if cache is not None and cache[0] is not None:
            conv_state = cache[0]
        else:
            conv_state = mx.zeros(
                (B, self.conv_kernel_size - 1, self.conv_dim),
                dtype=inputs.dtype,
            )

        if mask is not None:
            qkv = mx.where(mask[..., None], qkv, 0)
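        # Prepend the cached tail of the previous context so the causal conv
        # sees the last (kernel_size - 1) tokens, then stash the new tail
        # for the next decoding step.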
        conv_input = mx.concatenate([conv_state, qkv], axis=1)
        if cache is not None:
            cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :]
        conv_out = nn.silu(self.conv1d(conv_input))

        q, k, v = [
            t.reshape(B, S, h, d)
            for t, h, d in zip(
                mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1),
                [self.num_k_heads, self.num_k_heads, self.num_v_heads],
                [self.head_k_dim, self.head_k_dim, self.head_v_dim],
            )
        ]

        state = cache[1] if cache else None
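        # RMS-normalize q and k without learned weights; the inv_scale
        # factors fold the attention scaling into the activations before
        # the recurrent delta-rule update.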
        inv_scale = k.shape[-1] ** -0.5
        q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
        k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)

        out, state = gated_delta_update(
            q,
            k,
            v,
            a,
            b,
            self.A_log,
            self.dt_bias,
            state,
            mask,
            use_kernel=not self.training,
        )

        if cache is not None:
            cache[1] = state

        out = self.norm(out, z)
        return self.out_proj(out.reshape(B, S, -1))


class DecoderLayer(nn.Module):
    def __init__(self, args: TextModelArgs, layer_idx: int):
        super().__init__()
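        # Every `full_attention_interval`-th layer (1-indexed) uses full
        # softmax attention; all other layers use the linear GatedDeltaNet.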
        self.is_linear = (layer_idx + 1) % args.full_attention_interval != 0
        if self.is_linear:
            self.linear_attn = GatedDeltaNet(args)
        else:
            self.self_attn = Attention(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        if args.num_experts > 0:
            self.mlp = SparseMoeBlock(args)
        else:
            self.mlp = MLP(args.hidden_size, args.intermediate_size)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        if self.is_linear:
            r = self.linear_attn(self.input_layernorm(x), mask, cache)
        else:
            r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out


class Qwen3_5TextModel(nn.Module):
    def __init__(self, args: TextModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            DecoderLayer(args=args, layer_idx=i) for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
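        # Indices of the first linear layer and the first full-attention
        # layer; their cache entries are used below to build the two mask
        # types.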
        self.ssm_idx = 0
        self.fa_idx = args.full_attention_interval - 1

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        if input_embeddings is not None:
            hidden_states = input_embeddings
        else:
            hidden_states = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        fa_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
        ssm_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])

        for layer, c in zip(self.layers, cache):
            mask = ssm_mask if layer.is_linear else fa_mask
            hidden_states = layer(hidden_states, mask=mask, cache=c)

        return self.norm(hidden_states)


class TextModel(nn.Module):
    def __init__(self, args: TextModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen3_5TextModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        out = self.model(inputs, cache, input_embeddings=input_embeddings)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers

    def make_cache(self):
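        # Linear layers carry two arrays (conv tail + recurrent state);
        # full-attention layers use a standard KV cache.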
        return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]

    def sanitize(self, weights):
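        # Detect an unconverted HF checkpoint: it still has `mtp.`
        # (multi-token prediction) weights and/or conv1d weights in HF
        # layout. Such checkpoints store zero-centered norm weights that
        # need a +1 shift.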
        has_mtp_weights = any("mtp." in k for k in weights)
        has_unsanitized_conv1d = any(
            "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items()
        )
        should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d
        weights = {k: v for k, v in weights.items() if "mtp." not in k}

        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)

        norm_keys = (
            ".input_layernorm.weight",
            ".post_attention_layernorm.weight",
            "model.norm.weight",
            ".q_norm.weight",
            ".k_norm.weight",
        )
        for k, v in weights.items():
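            # HF stores these depthwise conv weights as (channels, 1, kernel);
            # MLX Conv1d expects (channels, kernel, 1).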
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if should_shift_norm_weights and any(k.endswith(sfx) for sfx in norm_keys):
if v.ndim == 1:
weights[k] = v + 1.0
return weights

    @property
    def quant_predicate(self):
        if self.args.num_experts <= 0:
            return None

        def predicate(path, _):
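            # Quantize the small MoE router and shared-expert gates at
            # 8 bits with group size 64; everything else gets the default
            # quantization settings.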
            if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"):
                return {"group_size": 64, "bits": 8}
            return True

        return predicate


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    text_config: dict

    @classmethod
    def from_dict(cls, params):
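        # Text-only checkpoints keep all fields at the top level; wrap them
        # so both text-only and multimodal configs load through text_config.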
if "text_config" not in params:
return cls(model_type=params["model_type"], text_config=params)
return super().from_dict(params)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.language_model = TextModel(TextModelArgs.from_dict(args.text_config))

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ):
        return self.language_model(
            inputs, cache=cache, input_embeddings=input_embeddings
        )

    def sanitize(self, weights):
        weights = tree_unflatten(list(weights.items()))
        weights = dict(tree_flatten(weights))

        sanitized = {}
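        # Drop the vision tower and remap HF prefixes so every remaining
        # weight lives under `language_model.model`.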
        for key, value in weights.items():
            if key.startswith("model.visual"):
                continue
            if key.startswith("model.language_model"):
                key = key.replace("model.language_model", "language_model.model")
            elif key.startswith("language_model."):
                pass
            else:
                key = "language_model." + key
            sanitized[key] = value
        return self.language_model.sanitize(sanitized)

    @property
    def layers(self):
        return self.language_model.model.layers

    def make_cache(self):
        return self.language_model.make_cache()

    @property
    def quant_predicate(self):
        return self.language_model.quant_predicate