Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ th {
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
| `MellumForCausalLM` | Mellum 2 | `JetBrains/Mellum2-12B-A2.5B-Base`, etc. | | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ |
| `MiMoV2ForCausalLM` | MiMoV2Pro | `XiaomiMiMo/MiMo-V2.5-Pro`, etc. | | ✅︎ |
Expand Down
1 change: 1 addition & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def check_available_online(
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"MellumForCausalLM": _HfExamplesInfo("JetBrains/Mellum2-12B-A2.5B-Base"),
"Qwen3NextForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
extras={"tiny-random": "tiny-random/qwen3-next-moe"},
Expand Down
253 changes: 253 additions & 0 deletions vllm/model_executor/models/mellum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

from .qwen3_moe import (
Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP,
Qwen3MoeModel,
Qwen3MoeSparseMoeBlock,
)
from .utils import PPMissingLayer, extract_layer_index, maybe_prefix


class MellumAttention(Qwen3MoeAttention):
"""
Differences from `Qwen3MoeAttention`:
- Supports `per_layer_sliding_window` for `Attention`.
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_parameters: dict[str, Any],
max_position_embeddings: int = 8192,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
cache_config: Any | None = None,
quant_config: Any | None = None,
prefix: str = "",
dual_chunk_attention_config: dict[str, Any] | None = None,
per_layer_sliding_window: int | None = None,
) -> None:
nn.Module.__init__(self)

self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config

self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

self.rotary_emb = get_rope(
self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{prefix}.attn",
**(
{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
}
if dual_chunk_attention_config
else {}
),
)

self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)


class MellumDecoderLayer(Qwen3MoeDecoderLayer):
"""
Differences from `Qwen3MoeDecoderLayer`:
- Supports interleaved SWA and per-layer RoPE scaling.
"""

def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
nn.Module.__init__(self)

config = vllm_config.model_config.hf_text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.hidden_size = config.hidden_size
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)

layer_idx = extract_layer_index(prefix)
layer_type = config.layer_types[layer_idx]
if layer_type == "sliding_attention":
sliding_window = getattr(config, "sliding_window", None)
else:
sliding_window = None
rope_parameters = config.rope_parameters[layer_type]

self.self_attn = MellumAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_parameters=rope_parameters,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
per_layer_sliding_window=sliding_window,
)

if config.mlp_layer_types[layer_idx] == "sparse":
self.mlp = Qwen3MoeSparseMoeBlock(
vllm_config=vllm_config, prefix=f"{prefix}.mlp"
)
else:
self.mlp = Qwen3MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)

self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)


@support_torch_compile
class MellumModel(Qwen3MoeModel):
"""
Differences from `Qwen3MoeModel`:
- Uses `MellumDecoderLayer`.
"""

def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=MellumDecoderLayer,
)


class MellumForCausalLM(Qwen3MoeForCausalLM):
"""
Differences from `Qwen3MoeForCausalLM`:
- Uses `MellumModel`.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
if "dense" in getattr(config, "mlp_layer_types", []):
self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
self.model = MellumModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)

self.expert_weights = []

self.moe_layers = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue

assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)

if example_layer is None:
raise RuntimeError("No MoE layer found in the model.layers.")

self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
"LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
"MellumForCausalLM": ("mellum", "MellumForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
Expand Down
1 change: 1 addition & 0 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __getitem__(self, key):
RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct)
mlp_speculator="MLPSpeculatorConfig",
medusa="MedusaConfig",
mellum="MellumConfig",
midashenglm="MiDashengLMConfig",
moondream3="Moondream3Config",
eagle="EAGLEConfig",
Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"LagunaConfig": "vllm.transformers_utils.configs.laguna",
"Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe",
"MedusaConfig": "vllm.transformers_utils.configs.medusa",
"MellumConfig": "vllm.transformers_utils.configs.mellum",
"MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm",
"MLPSpeculatorConfig": "vllm.transformers_utils.configs.mlp_speculator",
"Moondream3Config": "vllm.transformers_utils.configs.moondream3",
Expand Down Expand Up @@ -117,6 +118,7 @@
"LagunaConfig",
"Lfm2MoeConfig",
"MedusaConfig",
"MellumConfig",
"MiDashengLMConfig",
"MLPSpeculatorConfig",
"Moondream3Config",
Expand Down
7 changes: 7 additions & 0 deletions vllm/transformers_utils/configs/mellum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import Qwen3MoeConfig


class MellumConfig(Qwen3MoeConfig):
model_type = "mellum"
Loading