From 620ad559742de691604d49f5b8f53762770aefc9 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Fri, 29 May 2026 14:43:50 +0000 Subject: [PATCH 1/2] feat: Add support for JetBrains' Mellum v2 code generation model [Mellum](https://www.jetbrains.com/mellum/) v2 is an update to JetBrains' open-weights code generation model that is built on a Mixture-of-Experts architecture. --- docs/models/supported_models.md | 1 + tests/models/registry.py | 1 + vllm/model_executor/models/mellum.py | 254 ++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/mellum.py | 7 + 7 files changed, 267 insertions(+) create mode 100644 vllm/model_executor/models/mellum.py create mode 100644 vllm/transformers_utils/configs/mellum.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 0654a59caf13..4612b4c423f7 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -437,6 +437,7 @@ th { | `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | +| `MellumForCausalLM` | Mellum 2 | `JetBrains/Mellum2-12B-A2.5B-Base`, etc. | | ✅︎ | | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | | `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ | | `MiMoV2ForCausalLM` | MiMoV2Pro | `XiaomiMiMo/MiMo-V2.5-Pro`, etc. | | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 3ef6997621da..cfc3eb819c9a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -522,6 +522,7 @@ def check_available_online( "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "MellumForCausalLM": _HfExamplesInfo("JetBrains/Mellum2-12B-A2.5B-Base"), "Qwen3NextForCausalLM": _HfExamplesInfo( "Qwen/Qwen3-Next-80B-A3B-Instruct", extras={"tiny-random": "tiny-random/qwen3-next-moe"}, diff --git a/vllm/model_executor/models/mellum.py b/vllm/model_executor/models/mellum.py new file mode 100644 index 000000000000..6ffee9011dc2 --- /dev/null +++ b/vllm/model_executor/models/mellum.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + +from .qwen3_moe import ( + Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeMLP, + Qwen3MoeModel, + Qwen3MoeSparseMoeBlock, +) +from .utils import PPMissingLayer, extract_layer_index, maybe_prefix + + +class MellumAttention(Qwen3MoeAttention): + """ + Differences from `Qwen3MoeAttention`: + - Supports `per_layer_sliding_window` for `Attention`. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_parameters: dict[str, Any], + max_position_embeddings: int = 8192, + head_dim: int | None = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Any | None = None, + quant_config: Any | None = None, + prefix: str = "", + dual_chunk_attention_config: dict[str, Any] | None = None, + per_layer_sliding_window: int | None = None, + ) -> None: + nn.Module.__init__(self) + + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + rope_parameters=rope_parameters, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{prefix}.attn", + **( + { + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } + if dual_chunk_attention_config + else {} + ), + ) + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + +class MellumDecoderLayer(Qwen3MoeDecoderLayer): + """ + Differences from `Qwen3MoeDecoderLayer`: + - Supports interleaved SWA and per-layer RoPE scaling. + """ + + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_text_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.hidden_size = config.hidden_size + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) + + layer_idx = extract_layer_index(prefix) + layer_type = config.layer_types[layer_idx] + if layer_type == "sliding_attention": + sliding_window = getattr(config, "sliding_window", None) + else: + sliding_window = None + rope_parameters = config.rope_parameters[layer_type] + + self.self_attn = MellumAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_parameters=rope_parameters, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, + per_layer_sliding_window=sliding_window, + ) + + if config.mlp_layer_types[layer_idx] == "sparse": + self.mlp = Qwen3MoeSparseMoeBlock( + vllm_config=vllm_config, prefix=f"{prefix}.mlp" + ) + else: + self.mlp = Qwen3MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + +@support_torch_compile +class MellumModel(Qwen3MoeModel): + """ + Differences from `Qwen3MoeModel`: + - Uses `MellumDecoderLayer`. + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=MellumDecoderLayer, + ) + + +class MellumForCausalLM(Qwen3MoeForCausalLM): + """ + Differences from `Qwen3MoeForCausalLM`: + - Uses `MellumModel`. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_text_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + if "dense" in getattr(config, "mlp_layer_types", []): + self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] + self.model = MellumModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + self.expert_weights = [] + + self.moe_layers = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3MoeDecoderLayer) + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No MoE layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 94472d27e1cd..d96ceeb4b506 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -159,6 +159,7 @@ "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), + "MellumForCausalLM": ("mellum", "MellumForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f940739a96c3..8339c183c0fd 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -116,6 +116,7 @@ def __getitem__(self, key): RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) mlp_speculator="MLPSpeculatorConfig", medusa="MedusaConfig", + mellum="MellumConfig", midashenglm="MiDashengLMConfig", moondream3="Moondream3Config", eagle="EAGLEConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 5998e61dfd88..71f7723e4c80 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -49,6 +49,7 @@ "LagunaConfig": "vllm.transformers_utils.configs.laguna", "Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe", "MedusaConfig": "vllm.transformers_utils.configs.medusa", + "MellumConfig": "vllm.transformers_utils.configs.mellum", "MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm", "MLPSpeculatorConfig": "vllm.transformers_utils.configs.mlp_speculator", "Moondream3Config": "vllm.transformers_utils.configs.moondream3", @@ -117,6 +118,7 @@ "LagunaConfig", "Lfm2MoeConfig", "MedusaConfig", + "MellumConfig", "MiDashengLMConfig", "MLPSpeculatorConfig", "Moondream3Config", diff --git a/vllm/transformers_utils/configs/mellum.py b/vllm/transformers_utils/configs/mellum.py new file mode 100644 index 000000000000..2bed53394b25 --- /dev/null +++ b/vllm/transformers_utils/configs/mellum.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import Qwen3MoeConfig + + +class MellumConfig(Qwen3MoeConfig): + model_type = "mellum" From f3a4e7df34fa29753b2099dea8d38800dfb60fcc Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Mon, 1 Jun 2026 09:50:42 +0000 Subject: [PATCH 2/2] Lints Signed-off-by: Madeesh Kannan --- vllm/model_executor/models/mellum.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/mellum.py b/vllm/model_executor/models/mellum.py index 6ffee9011dc2..bdbf0df7fd14 100644 --- a/vllm/model_executor/models/mellum.py +++ b/vllm/model_executor/models/mellum.py @@ -121,7 +121,6 @@ class MellumDecoderLayer(Qwen3MoeDecoderLayer): - Supports interleaved SWA and per-layer RoPE scaling. """ - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: nn.Module.__init__(self)