Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's `Mamba v1`
inclusionAI's `Bailing MoE e.g. Ling-family`, `Bailing MoE Linear e.g. Ling-Linear-family`,
Klear team - Kuaishou Technology's `Klear`, AI21 Lab's `Jamba`, IBM's `Granite MoE`,
Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, Nikity's `Lille130m`,
Alibaba Qwen's `Qwen3Next`, Tele-AI's `TeleChat3`, and Allenai's `OLMoE` and `Olmo 3`;
Alibaba Qwen's `Qwen3Next` and `Qwen3.5 MoE`, Tele-AI's `TeleChat3`, and Allenai's `OLMoE` and `Olmo 3`;
Helped add support for the following model architectures:
Alibaba Qwen's `Qwen3 & Qwen3MoE`; Added support for the following training algorithms:
`Full Weight Fine-Tuning`, and the `Muon` optimizer;
Expand All @@ -26,4 +26,4 @@ Added support for the following other features:
MoonshotAI's `Kimi-Linear`, LiquidAI's `LFM2` and `LFM2 MoE`,
Google DeepMind's `Gemma 3`, TII's `Falcon H1` and InternLM's `InternLM 2.5`.
- Ivan Fioravanti: Added support for the following architectures:
ServiceNow-AI's `Apriel 1.5`, Tencent's `Hunyuan Dense V1` and `Hunyuan MoE V1`.
ServiceNow-AI's `Apriel 1.5`, Tencent's `Hunyuan Dense V1` and `Hunyuan MoE V1`.
62 changes: 62 additions & 0 deletions mlx_lm/models/qwen3_5_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

from . import qwen3_5_moe_text
from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    """Top-level configuration for the Qwen3.5 MoE wrapper model.

    ``text_config`` is the dictionary that ``Model`` passes to
    ``qwen3_5_moe_text.ModelArgs.from_dict`` to configure the language model.
    """

    model_type: str
    text_config: dict

    @classmethod
    def from_dict(cls, params):
        """Build the arguments from a configuration dictionary.

        When ``params`` carries a ``"text_config"`` entry, construction is
        delegated to the parent ``from_dict``. Otherwise the whole ``params``
        dictionary is used as the text configuration, with ``model_type``
        read from ``params["model_type"]``.
        """
        if "text_config" in params:
            return super().from_dict(params)
        # Flat config: no nested text section, so reuse the dict itself.
        return cls(model_type=params["model_type"], text_config=params)


class Model(nn.Module):
    """Wrapper that exposes the text model under ``language_model``."""

    def __init__(self, args: ModelArgs):
        """Create the wrapped language model from ``args.text_config``."""
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        text_args = qwen3_5_moe_text.ModelArgs.from_dict(args.text_config)
        self.language_model = qwen3_5_moe_text.Model(text_args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Forward the call unchanged to the wrapped language model."""
        return self.language_model(
            inputs,
            cache=cache,
            input_embeddings=input_embeddings,
        )

    def sanitize(self, weights):
        """Return weights renamed to live under ``language_model.``.

        The top-level ``vision_tower`` subtree is dropped, and every
        remaining key that does not already start with ``language_model.``
        gets that prefix prepended.
        """
        # Round-trip through the nested form so the whole "vision_tower"
        # subtree can be removed with a single pop.
        nested = tree_unflatten(list(weights.items()))
        nested.pop("vision_tower", None)

        prefix = "language_model."
        return {
            (name if name.startswith(prefix) else prefix + name): tensor
            for name, tensor in tree_flatten(nested)
        }

    @property
    def layers(self):
        """Layers of the wrapped language model (``language_model.model.layers``)."""
        return self.language_model.model.layers

    def make_cache(self):
        """Return the cache produced by the wrapped language model."""
        return self.language_model.make_cache()
Loading