Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ Specified using `--task generate`.
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B | `tencent/Hunyuan-7B-Instruct`, `tencent/Hunyuan-7B-Pretrain`, etc. | | | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ |
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
Expand Down Expand Up @@ -387,7 +389,7 @@ Specified using `--task generate`.
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | |

Expand Down
22 changes: 20 additions & 2 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ argcomplete==3.5.1
# via datamodel-code-generator
arrow==1.3.0
# via isoduration
async-timeout==5.0.1
# via
# aiohttp
# redis
attrs==24.2.0
# via
# aiohttp
Expand Down Expand Up @@ -141,6 +145,11 @@ eval-type-backport==0.2.2
# via mteb
evaluate==0.4.3
# via lm-eval
exceptiongroup==1.3.0
# via
# anyio
# hypothesis
# pytest
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
Expand Down Expand Up @@ -690,7 +699,6 @@ setuptools==77.0.3
# via
# mamba-ssm
# pytablewriter
# torch
# triton
shellingham==1.5.4
# via typer
Expand Down Expand Up @@ -753,8 +761,13 @@ tokenizers==0.21.1
# via
# -r requirements/test.in
# transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.2.1
# via schemathesis
# via
# black
# pytest
# schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.7.0+cu128
Expand Down Expand Up @@ -828,13 +841,18 @@ types-python-dateutil==2.9.0.20241206
# via arrow
typing-extensions==4.12.2
# via
# anyio
# black
# exceptiongroup
# huggingface-hub
# librosa
# mistral-common
# mteb
# multidict
# pqdm
# pydantic
# pydantic-core
# rich
# torch
# typer
# typing-inspection
Expand Down
4 changes: 3 additions & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ def check_available_online(
trust_remote_code=True),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst",
min_transformers_version="4.53"),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to register the dense model here as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Added HunYuanDenseV1ForCausalLM.

Note:
We are currently doing some governance work on our HF models. The previously open-sourced dense model uses the architecture name HunYuanForCausalLM, while subsequent dense models will use HunYuanDenseV1ForCausalLM. To run the previous model, you need to change its architecture name accordingly. This PR does not include adaptation for the previous model.

"HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
Expand Down Expand Up @@ -489,4 +491,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
47 changes: 44 additions & 3 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,41 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose base is rescaled by a Dynamic NTK alpha factor.

    Reuses the parent RotaryEmbedding and only overrides how the cos/sin
    cache is computed: the rotary base is multiplied by
    ``scaling_alpha ** (rotary_dim / (rotary_dim - 2))`` before the inverse
    frequencies are derived.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        # Stored before delegating to the parent constructor; presumably the
        # parent builds the cache during __init__ and therefore needs
        # ``scaling_alpha`` to exist already (verify against RotaryEmbedding).
        self.scaling_alpha = scaling_alpha
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table used by Hunyuan's NTK-alpha RoPE.

        Returns:
            A tensor with one row per position in
            ``range(max_position_embeddings)``; each row is the cosines of
            the position/frequency products followed by their sines.
        """
        dim = self.rotary_dim
        # NTK-alpha rescaling of the rotary base.
        scaled_base = self.base * self.scaling_alpha**(dim / (dim - 2))
        inv_freq = self._compute_inv_freq(scaled_base)

        positions = torch.arange(self.max_position_embeddings,
                                 dtype=torch.float)
        # Outer product of positions and inverse frequencies.
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
dim: int,
Expand Down Expand Up @@ -1810,9 +1845,15 @@ def get_rope(
mixed_b)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
scaling_alpha = rope_scaling["alpha"]
if scaling_alpha:
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_alpha, dtype)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling[
Expand Down
Loading
Loading