Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
USE_XFORMERS_OPS = None

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and envs.VLLM_ROCM_USE_AITER_MLA

Check failure on line 33 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/attention/layer.py:33:81: E501 Line too long (149 > 80)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")

Check failure on line 37 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/attention/layer.py:37:13: G004 Logging statement uses f-string

def check_xformers_availability():
global USE_XFORMERS_OPS
Expand Down Expand Up @@ -536,12 +536,12 @@
key,
value,
kv_cache,
attn_metadata,
output=output,

Check failure on line 540 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (SIM101)

vllm/attention/layer.py:539:13: SIM101 Multiple `isinstance` calls for expression, merge into a single call
output_scale=output_scale,
positions=positions)
else:
assert positions is None, f"positions must be None {positions=}"

Check failure on line 544 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/attention/layer.py:544:81: E501 Line too long (142 > 80)
self.impl.forward(self,
query,
key,
Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS: bool = True

def get_default_cache_root():
return os.getenv(
Expand Down Expand Up @@ -1237,15 +1238,15 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),

Check failure on line 1241 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1241:81: E501 Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure on line 1245 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1245:81: E501 Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),

Check failure on line 1249 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1249:81: E501 Line too long (94 > 80)
# Use AITER Triton fused FP8 per-token group quant + FP8 batched GEMM
"VLLM_ROCM_USE_AITER_TRITON_FP8_BMM":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FP8_BMM", "1"))),
Expand All @@ -1271,6 +1272,11 @@
# Apply preshuffling for mxfp4 scales for ROCm backend
"ROCM_TRITON_MOE_PRESHUFFLE_SCALES":
lambda: bool(int(os.getenv("ROCM_TRITON_MOE_PRESHUFFLE_SCALES", "1"))),

# Use AITER Triton fused gate_up_proj + moe_gate and fused gemm reduce + silu_mul + fp8 group quant
"VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS", "1"))),

}

# --8<-- [end:env-vars-definition]
Expand Down
91 changes: 85 additions & 6 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@
maybe_prefix)
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.logger import init_logger
logger = init_logger(__name__)

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT
VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS
from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT

VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and envs.VLLM_ROCM_USE_AITER_MLA
Expand All @@ -81,17 +84,75 @@

if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD:
from aiter.ops.triton.fused_mul_add import fused_mul_add

if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS:
from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16
from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant
import aiter as rocm_aiter
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def rocm_aiter_triton_fused_shared_expert_impl(
    hidden_states_shared: torch.Tensor,
    hidden_states_shared_scale: torch.Tensor,
    weight_gate_up: torch.Tensor,
    weight_scale_gate_up: torch.Tensor,
    hidden_states_moe_gate: torch.Tensor,
    weight_moe_gate: torch.Tensor,
    bias_shared: torch.Tensor,
    bias_moe_gate: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fused shared-expert gate_up GEMM together with the MoE gate.

    Two AITER Triton kernels are chained:

    1. ``fused_gemm_a8w8_blockscale_a16w16`` (called with
       ``skip_reduce=True``) yields the shared-expert gate_up output and
       the router logits in one launch.
    2. ``fused_reduce_act_mul_fp8_group_quant`` applies the ``"silu"``
       activation and FP8 group quantization to the gate_up output.

    Returns:
        ``(shared_output_q, shared_output_s, router_logits)`` -- the
        quantized shared-expert activations, their scales, and the router
        logits.
    """
    gate_up_out, router_logits = fused_gemm_a8w8_blockscale_a16w16(
        hidden_states_shared,
        weight_gate_up,
        hidden_states_shared_scale,
        weight_scale_gate_up,
        hidden_states_moe_gate,
        weight_moe_gate,
        bias_fp8=bias_shared,
        bias_bf16=bias_moe_gate,
        dtype=hidden_states_moe_gate.dtype,
        skip_reduce=True,
    )

    # A 3-D GEMM output means the router logits must go through the second
    # kernel as well (presumably an un-reduced partial-sum layout -- confirm
    # against the aiter kernel documentation).
    logits_need_reduce = gate_up_out.dim() == 3
    quantized, reduced_logits = fused_reduce_act_mul_fp8_group_quant(
        gate_up_out,
        activation="silu",
        x2=router_logits if logits_need_reduce else None,
        group_size=rocm_aiter_fp8_quant_group_size,
        dtype_quant=rocm_aiter_fp8_dtype,
    )
    shared_output_q, shared_output_s = quantized
    if logits_need_reduce:
        router_logits = reduced_logits
    return shared_output_q, shared_output_s, router_logits

def rocm_aiter_triton_fused_shared_expert_fake(
    hidden_states_shared: torch.Tensor,
    hidden_states_shared_scale: torch.Tensor,
    weight_gate_up: torch.Tensor,
    weight_scale_gate_up: torch.Tensor,
    hidden_states_moe_gate: torch.Tensor,
    weight_moe_gate: torch.Tensor,
    bias_shared: torch.Tensor,
    bias_moe_gate: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only counterpart of the fused shared-expert op.

    No kernels run here: the function only allocates uninitialized tensors
    with the shapes, dtypes and device of the three outputs. It is the
    ``fake_impl`` passed when the custom op is registered.

    Returns:
        ``(shared_output_q, shared_output_s, router_logits)`` placeholders.

    Raises:
        AssertionError: if the gate_up weight has an odd number of rows, if
            half of that row count is not 256, or if it differs from the
            row count of ``weight_moe_gate``.
    """
    num_tokens = hidden_states_shared.shape[0]
    gate_up_rows = weight_gate_up.shape[0]
    gate_rows = weight_moe_gate.shape[0]
    target_device = hidden_states_shared.device

    # gate and up projections are stacked, so each takes half of the rows.
    assert gate_up_rows % 2 == 0
    half_rows = gate_up_rows // 2
    assert half_rows == 256, f"{weight_gate_up.shape}"
    assert half_rows == gate_rows, f"{weight_moe_gate.shape}"

    # One scale per quantization group along the last dim (ceil division).
    group = rocm_aiter_fp8_quant_group_size
    num_groups = -(-half_rows // group)

    shared_output_q = torch.empty(
        (num_tokens, half_rows),
        dtype=rocm_aiter_fp8_dtype,
        device=target_device,
    )
    shared_output_s = torch.empty(
        (num_tokens, num_groups),
        dtype=torch.float32,
        device=target_device,
    )
    router_logits = torch.empty(
        (num_tokens, gate_rows),
        dtype=hidden_states_moe_gate.dtype,
        device=target_device,
    )
    return shared_output_q, shared_output_s, router_logits

# Register the fused shared-expert op under the name
# "rocm_aiter_triton_fused_shared_expert" (the model's forward pass invokes it
# as torch.ops.vllm.rocm_aiter_triton_fused_shared_expert). The real
# implementation is paired with its shape-only fake, and no arguments are
# declared as mutated.
direct_register_custom_op(
op_name="rocm_aiter_triton_fused_shared_expert",
op_func=rocm_aiter_triton_fused_shared_expert_impl,
mutates_args=[],
fake_impl=rocm_aiter_triton_fused_shared_expert_fake,
dispatch_key=current_platform.dispatch_key,
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS = False

VLLM_ROCM_USE_AITER_MLA = envs.VLLM_ROCM_USE_AITER_MLA

# Report the resolved AITER feature switches when the module is imported.
# The values are passed as lazy logging arguments instead of being baked into
# f-strings (Ruff G004), so formatting only happens if the record is emitted.
# "%r" reproduces the text that the "{name=}" f-string form rendered.
logger.info(
    "[Aiter] VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=%r "
    "VLLM_ROCM_USE_AITER_MLA=%r",
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE,
    VLLM_ROCM_USE_AITER_MLA,
)
logger.info(
    "[Aiter] VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD=%r",
    VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD,
)
logger.info(
    "[Aiter] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT=%r",
    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT,
)
logger.info(
    "[Aiter] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT=%r",
    VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT,
)
logger.info(
    "[Aiter] VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS=%r",
    VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS,
)

class DeepseekV2MLP(nn.Module):

Expand Down Expand Up @@ -168,7 +229,11 @@ def __init__(
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32))
e_score_correction_bias = self.gate.e_score_correction_bias
if is_rocm_aiter_moe_enabled():
e_score_correction_bias = self.gate.e_score_correction_bias.to(torch.bfloat16)
else:
e_score_correction_bias = None
self.gate.e_score_correction_bias = None

# Load balancing settings.
Expand Down Expand Up @@ -200,7 +265,7 @@ def __init__(
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
e_score_correction_bias=e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)

Expand All @@ -225,10 +290,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states_shared)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS and self.n_shared_experts is not None:
hidden_states_shared, hidden_states_shared_scale = hidden_states_shared
shared_output_q, shared_output_s, router_logits = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
hidden_states_shared = hidden_states_shared,
hidden_states_shared_scale = hidden_states_shared_scale,
weight_gate_up = self.shared_experts.gate_up_proj.weight,
weight_scale_gate_up = self.shared_experts.gate_up_proj.weight_scale_inv,
hidden_states_moe_gate = hidden_states,
weight_moe_gate = self.gate.weight,
bias_shared = self.shared_experts.gate_up_proj.bias if not self.shared_experts.gate_up_proj.skip_bias_add else None,
bias_moe_gate = self.gate.bias if not self.gate.skip_bias_add else None,
)
shared_output, _ = self.shared_experts.down_proj(shared_output_q, x_quant_scales = shared_output_s)
else:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states_shared)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
final_hidden_states = self.experts(hidden_states=hidden_states,
Expand Down
Loading