
Commit a8ce7ea

support HunYuanMoEV1ForCausalLM
Co-authored-by: quinnrong <[email protected]>
Signed-off-by: aiyiwang <[email protected]>
1 parent 2582683 commit a8ce7ea


6 files changed: +1781 -3 lines changed


docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -388,6 +388,7 @@ Specified using `--task generate`.
 | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
 | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
 | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | |
+| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ |
 
 !!! note
     Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
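
For context, a model listed in this table can be exercised through vLLM's offline Python API. The snippet below is a minimal sketch and not part of the commit: the checkpoint name comes from the table row added above, while the sampling settings and the trust_remote_code flag are assumptions to adapt to your environment.

from vllm import LLM, SamplingParams

# Sketch only: load the newly listed Hunyuan MoE checkpoint and run one prompt.
llm = LLM(model="tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
sampling = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Briefly explain mixture-of-experts models."], sampling)
print(outputs[0].outputs[0].text)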

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -259,6 +259,7 @@ def check_available_online(
     "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
     "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                        trust_remote_code=True),
+    "HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct"),
     # [Encoder-decoder]
     "BartModel": _HfExamplesInfo("facebook/bart-base"),
     "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 43 additions & 3 deletions
@@ -532,6 +532,40 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         return cache
 
 
+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK alpha.
+
+    Based on the original RotaryEmbedding implementation.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        # For Hunyuan DynamicNTKAlphaRotaryEmbedding
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (self.rotary_dim / (self.rotary_dim - 2))
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 # Inverse dim formula to find dim based on number of rotations
 def _yarn_find_correction_dim(num_rotations: int,
                               dim: int,
@@ -1810,9 +1844,15 @@ def get_rope(
                                                   mixed_b)
         elif scaling_type == "dynamic":
             scaling_factor = rope_scaling["factor"]
-            rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor, dtype)
+            scaling_alpha = rope_scaling["alpha"]
+            if scaling_alpha:
+                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_alpha, dtype)
+            else:
+                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor, dtype)
         elif scaling_type == "yarn":
             scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
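
To make the new "dynamic" branch concrete: when the incoming rope_scaling carries a truthy "alpha" value, get_rope now builds a DynamicNTKAlphaRotaryEmbedding, which enlarges the rotary base as base' = base * alpha ** (rotary_dim / (rotary_dim - 2)) before computing the cos/sin cache. The standalone sketch below illustrates that adjustment; the numeric values (base 10000, rotary_dim 128, alpha 1000) are assumptions for illustration, not values taken from any particular config.

import torch

# Illustrative values only; a real model supplies these via its config.
base = 10000.0
rotary_dim = 128
scaling_alpha = 1000.0

# NTK-alpha adjustment: enlarge the rotary base before deriving inverse frequencies.
adjusted_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))

# Same inverse-frequency formula as RotaryEmbedding._compute_inv_freq.
exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
inv_freq = 1.0 / (adjusted_base ** exponents)

print(f"adjusted base: {adjusted_base:.3e}")   # ~1.1e+07 for these example values
print(f"smallest inv_freq: {inv_freq[-1]:.3e}")  # lower frequencies stretch the usable context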

0 commit comments
