Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ th {
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
| `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2`. | ✅︎ | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | ✅︎ |
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ def __init__(
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
use_presharded_weights: bool = False,
):
super().__init__()
if params_dtype is None:
Expand Down Expand Up @@ -866,6 +867,7 @@ def __init__(
self.e_score_correction_bias = e_score_correction_bias
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self.use_presharded_weights = use_presharded_weights

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
Expand Down Expand Up @@ -1086,10 +1088,11 @@ def _load_w13(self,
tp_rank: int,
load_full: bool = False):

should_skip_sharding = self.use_presharded_weights or load_full
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
if not load_full:
if not should_skip_sharding:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
Expand All @@ -1110,11 +1113,12 @@ def _load_w2(self,
tp_rank: int,
load_full: bool = False):

should_skip_sharding = self.use_presharded_weights or load_full
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if not load_full:
if not should_skip_sharding:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
Expand Down
94 changes: 94 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import torch

from .base import RotaryEmbedding
from .common import (yarn_find_correction_range, yarn_get_mscale,
yarn_linear_ramp_mask)


class Grok1ScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with YaRN-style context-length scaling for Grok.

    Scales RoPE in a way similar to the YaRN method
    (https://arxiv.org/pdf/2309.00071). The extrapolation strategy is
    selectable via ``extra_method``:

    - ``"original"``: pure extrapolation (no interpolation).
    - ``"yarn"`` / ``"yarn_linear"``: linear blend of interpolated and
      extrapolated inverse frequencies using the YaRN ramp mask.
    - ``"yarn_log"``: the same blend performed in log-frequency space.
    - ``"theta_scale"``: rescale the RoPE base exponent instead of
      blending frequencies.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extra_method: str = "yarn_log",
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        """Initialize the scaled rotary embedding.

        Args:
            head_size: Size of each attention head.
            rotary_dim: Number of dimensions rotary embedding is applied to.
            max_position_embeddings: Original (pre-scaling) context length.
            base: RoPE theta base.
            is_neox_style: Whether to use NeoX-style (half-rotated) layout.
            scaling_factor: Context-length extension factor.
            dtype: Output dtype of the cos/sin cache.
            extra_method: Extrapolation strategy (see class docstring).
            extrapolation_factor: Multiplier applied to the YaRN ramp mask.
            attn_factor: Extra multiplier folded into ``mscale``.
            beta_fast: YaRN fast-dimension correction boundary.
            beta_slow: YaRN slow-dimension correction boundary.
        """
        self.scaling_factor = scaling_factor
        self.extra_method = extra_method
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        # NOTE(review): mscale is computed here but NOT applied to the
        # cos/sin cache in _compute_cos_sin_cache below — presumably it is
        # consumed elsewhere (e.g. attention scaling); confirm with callers.
        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
        # Base-class __init__ triggers _compute_cos_sin_cache(), so all
        # attributes it reads must be set before this call.
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Compute inverse frequencies per ``self.extra_method``.

        Args:
            scaling_factor: Context-length extension factor.

        Returns:
            A 1-D float tensor of length ``rotary_dim // 2``.

        Raises:
            ValueError: If ``self.extra_method`` is not recognized.
        """
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation:
        # mask is 1 where pure extrapolation applies, 0 where interpolation
        # applies, with a linear ramp in between.
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        if self.extra_method in ["original"]:
            inv_freq = inv_freq_extrapolation
        elif self.extra_method in ["yarn", "yarn_linear"]:
            # Linear blend between interpolated and extrapolated frequencies.
            inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) +
                        inv_freq_extrapolation * inv_freq_mask)
        elif self.extra_method == "yarn_log":
            # Same blend, but performed on log-frequencies (geometric mix).
            inv_freq = torch.exp(
                torch.log(inv_freq_extrapolation) * inv_freq_mask +
                torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask))
        elif self.extra_method == "theta_scale":
            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
            theta_scale_exponent = self.base**(
                math.log(self.max_position_embeddings * self.scaling_factor /
                         (2 * math.pi)) /
                math.log(self.max_position_embeddings / (2 * math.pi)))
            # float ** tensor already yields a tensor; casting with .to()
            # avoids torch.tensor(...) copy-constructing from a tensor
            # (which warns and makes a redundant copy).
            inv_freq = (1.0 /
                        theta_scale_exponent**(exponents /
                                               self.rotary_dim)).to(
                                                   torch.float32)
        else:
            raise ValueError(
                f"Unknown extrapolation method: {self.extra_method}")
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin cache over the scaled position range.

        Returns:
            Tensor of shape
            ``(max_position_embeddings * scaling_factor, rotary_dim)``
            with cos values in the first half of the last dim and sin
            values in the second half.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        # NOTE(review): self.mscale is intentionally not multiplied into
        # cos/sin here (a previous variant did) — verify that the magnitude
        # correction is applied downstream instead.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
Loading
Loading