1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -363,6 +363,7 @@ th {
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
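For context, the newly listed dense checkpoint can be exercised like any other decoder-only entry in this table. A minimal offline-inference sketch follows; only the model id comes from the row added above, the prompt and sampling settings are illustrative:

```python
# Minimal sketch: offline inference with the newly documented dense checkpoint.
# Assumes a local vLLM install and enough GPU memory for a 7B model.
from vllm import LLM, SamplingParams

llm = LLM(
    model="tencent/Hunyuan-7B-Instruct-0124",
    trust_remote_code=True,  # HunYuan repos ship custom config/tokenizer code
)
params = SamplingParams(temperature=0.7, max_tokens=128)
outputs = llm.generate(
    ["Give a short introduction to large language models."], params)
print(outputs[0].outputs[0].text)
```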
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -199,6 +199,8 @@ def check_available_online(
trust_remote_code=True),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct",
trust_remote_code=True),
"HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124",
trust_remote_code=True),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
trust_remote_code=True),
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
vllm/model_executor/models/hunyuan_v1.py
@@ -61,6 +61,19 @@
make_layers)


def _is_moe(config: PretrainedConfig) -> bool:
num_experts = getattr(config, "num_experts", None)
if isinstance(num_experts, int):
return num_experts > 1
if isinstance(num_experts, list) and num_experts:
# Ensure all elements are integers before calling max.
if all(isinstance(e, int) for e in num_experts):
return max(num_experts) > 1
else:
return False
return False


def _get_cla_factor(config: PretrainedConfig) -> int:
if not getattr(config, "use_cla", False):
return 1
@@ -140,8 +153,8 @@ def __init__(
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
if hasattr(config, "head_dim"):

if hasattr(config, "head_dim") and config.head_dim:
self.head_dim = config.head_dim
elif hasattr(config, "attention_head_dim"):
self.head_dim = config.attention_head_dim
@@ -490,12 +503,23 @@ def __init__(
else:
raise RuntimeError(f"Unsupported attention type: {attention_type}")

self.mlp = HunYuanSparseMoeBlock(
config=config,
quant_config=quant_config,
layer_id=layer_id,
prefix=f"{prefix}.mlp",
)
if _is_moe(config):
self.mlp = HunYuanSparseMoeBlock(
config=config,
quant_config=quant_config,
layer_id=layer_id,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = HunYuanMLP(
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)

self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -642,15 +666,17 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
return torch.concat((q, k, v))

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
else:
return []

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
cla_factor = _get_cla_factor(self.config)
@@ -815,7 +841,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return loaded_params


class HunYuanMoEV1ForCausalLM(nn.Module, SupportsLoRA):
class HunYuanV1Base(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -901,3 +927,11 @@ def load_weights(self, weights: Iterable[tuple[str,

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()


class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
pass


class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
pass
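The dense-vs-MoE split introduced above hinges entirely on the `_is_moe` predicate. Below is a small self-contained sketch of how it classifies typical config shapes; the `SimpleNamespace` objects stand in for `PretrainedConfig` and are purely illustrative:

```python
# Illustrative check of the _is_moe predicate from the diff above,
# using SimpleNamespace as a stand-in for PretrainedConfig.
from types import SimpleNamespace


def _is_moe(config) -> bool:
    num_experts = getattr(config, "num_experts", None)
    if isinstance(num_experts, int):
        return num_experts > 1
    if isinstance(num_experts, list) and num_experts:
        # Per-layer expert counts: treat as MoE if any layer has >1 expert.
        if all(isinstance(e, int) for e in num_experts):
            return max(num_experts) > 1
    return False


dense = SimpleNamespace()                        # no num_experts -> dense HunYuanMLP
moe = SimpleNamespace(num_experts=64)            # scalar count -> HunYuanSparseMoeBlock
per_layer = SimpleNamespace(num_experts=[1, 8])  # per-layer list -> MoE

assert not _is_moe(dense)
assert _is_moe(moe)
assert _is_moe(per_layer)
```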
3 changes: 2 additions & 1 deletion vllm/model_executor/models/registry.py
@@ -79,7 +79,8 @@
"GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501
"GritLM": ("gritlm", "GritLM"),
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
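Both HunYuan architectures now resolve through the same lazily imported `hunyuan_v1` module. A plain-dict sketch of the lookup these registry entries encode is shown below; it mirrors only the two changed entries, and the real table and import machinery in vLLM are larger:

```python
# Sketch of the (module, class) resolution encoded by the registry lines above.
# Assumes a vLLM install that includes this change, so that
# vllm.model_executor.models.hunyuan_v1 actually exists.
import importlib

_HUNYUAN_ENTRIES = {
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
}


def resolve(architecture: str) -> type:
    """Map an HF `architectures` string to the corresponding vLLM model class."""
    module_name, class_name = _HUNYUAN_ENTRIES[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)


# A checkpoint whose config declares architectures=["HunYuanDenseV1ForCausalLM"]
# now resolves to the dense subclass instead of failing as an unsupported model.
print(resolve("HunYuanDenseV1ForCausalLM"))
```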