Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def get_model_params(config):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"GlmMoeDsaForCausalLM",
"Glm4MoeForCausalLM",
"Glm4MoeLiteForCausalLM",
"NemotronHForCausalLM",
Expand Down
3 changes: 3 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ def check_available_online(
"zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0",
),
"GlmMoeDsaForCausalLM": _HfExamplesInfo(
"zai-org/GLM-5", min_transformers_version="5.0.1", is_available_online=False
),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
"bigcode/starcoder",
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
)

if model_arch == "DeepseekV32ForCausalLM":
if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]:
from vllm.platforms import current_platform

capability = current_platform.get_device_capability()
Expand Down
2 changes: 1 addition & 1 deletion vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def compute_hash(self) -> str:
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
initial_architecture = hf_config.architectures[0]
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
Expand Down
6 changes: 5 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def __init__(
qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=True,
is_neox_style=not getattr(config, "indexer_rope_interleave", True),
)
self.indexer = Indexer(
vllm_config,
Expand Down Expand Up @@ -1499,6 +1499,10 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass


class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM):
pass


# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
def get_spec_layer_idx_from_weight_name(
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
"Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
"GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
"GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
Expand Down
1 change: 1 addition & 0 deletions vllm/transformers_utils/model_arch_config_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def is_deepseek_mla(self) -> bool:
"deepseek_v3",
"deepseek_v32",
"deepseek_mtp",
"glm_moe_dsa",
"glm4_moe_lite",
"glm4_moe_lite_mtp",
"kimi_k2",
Expand Down
Loading