Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
USE_XFORMERS_OPS = None

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and envs.VLLM_ROCM_USE_AITER_MLA

Check failure on line 33 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/attention/layer.py:33:81: E501 Line too long (149 > 80)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")

Check failure on line 37 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/attention/layer.py:37:13: G004 Logging statement uses f-string

def check_xformers_availability():
global USE_XFORMERS_OPS
Expand Down Expand Up @@ -537,13 +537,13 @@
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,

Check failure on line 542 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (SIM101)

vllm/attention/layer.py:540:13: SIM101 Multiple `isinstance` calls for expression, merge into a single call
positions=positions)
else:
assert positions is None, f"positions must be None {positions=}"
self.impl.forward(self,

Check failure on line 546 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/attention/layer.py:546:81: E501 Line too long (142 > 80)
query,
key,
value,
Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS: bool = True
VLLM_ROCM_USE_AITER_TRITON_MLA: bool = False

def get_default_cache_root():
Expand Down Expand Up @@ -1238,15 +1239,15 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),

Check failure on line 1242 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1242:81: E501 Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure on line 1246 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1246:81: E501 Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),

Check failure on line 1250 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1250:81: E501 Line too long (94 > 80)
# Use AITER Triton fused FP8 per-token group quant + FP8 batched GEMM
"VLLM_ROCM_USE_AITER_TRITON_FP8_BMM":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FP8_BMM", "1"))),
Expand All @@ -1272,6 +1273,11 @@
# Apply preshuffling for mxfp4 scales for ROCm backend
"ROCM_TRITON_MOE_PRESHUFFLE_SCALES":
lambda: bool(int(os.getenv("ROCM_TRITON_MOE_PRESHUFFLE_SCALES", "1"))),

# Use AITER Triton fused gate_up_proj + moe_gate and fused gemm reduce + silu_mul + fp8 group quant
"VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS", "1"))),


# Use AITER Triton MLA
"VLLM_ROCM_USE_AITER_TRITON_MLA":
Expand Down
91 changes: 85 additions & 6 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@
maybe_prefix)
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.logger import init_logger
logger = init_logger(__name__)

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT
VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS
from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT

VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and envs.VLLM_ROCM_USE_AITER_MLA
Expand All @@ -81,17 +84,75 @@

if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD:
from aiter.ops.triton.fused_mul_add import fused_mul_add

if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS:
from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import fused_gemm_a8w8_blockscale_a16w16
from aiter.ops.triton.fused_fp8_quant import fused_reduce_act_mul_fp8_group_quant
import aiter as rocm_aiter
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def rocm_aiter_triton_fused_shared_expert_impl(
    hidden_states_shared: torch.Tensor,
    hidden_states_shared_scale: torch.Tensor,
    weight_gate_up: torch.Tensor,
    weight_scale_gate_up: torch.Tensor,
    hidden_states_moe_gate: torch.Tensor,
    weight_moe_gate: torch.Tensor,
    bias_shared: torch.Tensor,
    bias_moe_gate: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused shared-expert gate_up GEMM + MoE router-gate GEMM, followed by
    a fused SiLU-mul + FP8 per-group quantization.

    Returns:
        shared_output_q: FP8-quantized shared-expert activation.
        shared_output_s: float32 per-group quantization scales
            (group size ``rocm_aiter_fp8_quant_group_size``).
        router_logits: router logits in ``hidden_states_moe_gate``'s dtype.
    """
    # One kernel computes both the FP8 block-scaled shared-expert GEMM and
    # the BF16 router-gate GEMM.  skip_reduce=True defers the final
    # reduction so it can be fused into the activation/quant kernel below.
    shared_output, router_logits = fused_gemm_a8w8_blockscale_a16w16(
        hidden_states_shared,
        weight_gate_up,
        hidden_states_shared_scale,
        weight_scale_gate_up,
        hidden_states_moe_gate,
        weight_moe_gate,
        bias_fp8=bias_shared,
        bias_bf16=bias_moe_gate,
        dtype=hidden_states_moe_gate.dtype,
        skip_reduce=True,
    )
    if shared_output.dim() == 3:
        # NOTE(review): a 3-D output appears to carry an unreduced partial
        # axis from the skipped reduction; router_logits is then reduced in
        # the same fused kernel — confirm against the aiter kernel contract.
        (shared_output_q, shared_output_s), router_logits = \
            fused_reduce_act_mul_fp8_group_quant(
                shared_output,
                activation="silu",
                x2=router_logits,
                group_size=rocm_aiter_fp8_quant_group_size,
                dtype_quant=rocm_aiter_fp8_dtype,
            )
    else:
        # Already reduced: only the shared output needs act + quant.
        (shared_output_q, shared_output_s), _ = \
            fused_reduce_act_mul_fp8_group_quant(
                shared_output,
                activation="silu",
                x2=None,
                group_size=rocm_aiter_fp8_quant_group_size,
                dtype_quant=rocm_aiter_fp8_dtype,
            )
    return shared_output_q, shared_output_s, router_logits

def rocm_aiter_triton_fused_shared_expert_fake(
    hidden_states_shared: torch.Tensor,
    hidden_states_shared_scale: torch.Tensor,
    weight_gate_up: torch.Tensor,
    weight_scale_gate_up: torch.Tensor,
    hidden_states_moe_gate: torch.Tensor,
    weight_moe_gate: torch.Tensor,
    bias_shared: torch.Tensor,
    bias_moe_gate: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape/dtype-only fake of the fused shared-expert op for tracing.

    Allocates empty tensors with the shapes and dtypes the real kernel
    would produce, without doing any computation.
    """
    num_tokens = hidden_states_shared.shape[0]
    gate_up_rows = weight_gate_up.shape[0]
    num_logits = weight_moe_gate.shape[0]
    dev = hidden_states_shared.device
    # gate_up projection interleaves gate and up halves, so its output
    # width must be even; the quantized activation keeps only one half.
    assert gate_up_rows % 2 == 0
    half = gate_up_rows // 2
    assert half == 256, f"{weight_gate_up.shape}"
    assert half == num_logits, f"{weight_moe_gate.shape}"
    group = rocm_aiter_fp8_quant_group_size
    num_groups = (half + group - 1) // group  # ceil-div: one scale per group
    shared_output_q = torch.empty(
        (num_tokens, half), dtype=rocm_aiter_fp8_dtype, device=dev)
    shared_output_s = torch.empty(
        (num_tokens, num_groups), dtype=torch.float32, device=dev)
    router_logits = torch.empty(
        (num_tokens, num_logits),
        dtype=hidden_states_moe_gate.dtype, device=dev)
    return shared_output_q, shared_output_s, router_logits

# Register the fused impl/fake pair as a vLLM custom op; the fake impl
# supplies output shapes/dtypes so the op can be traced without running
# the real Triton kernel.  mutates_args=[] declares the op as pure.
direct_register_custom_op(
    op_name="rocm_aiter_triton_fused_shared_expert",
    op_func=rocm_aiter_triton_fused_shared_expert_impl,
    mutates_args=[],
    fake_impl=rocm_aiter_triton_fused_shared_expert_fake,
    dispatch_key=current_platform.dispatch_key,
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS = False

VLLM_ROCM_USE_AITER_MLA = envs.VLLM_ROCM_USE_AITER_MLA
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MLA=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS=}")

class DeepseekV2MLP(nn.Module):

Expand Down Expand Up @@ -168,7 +229,11 @@ def __init__(
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32))
e_score_correction_bias = self.gate.e_score_correction_bias
if is_rocm_aiter_moe_enabled():
e_score_correction_bias = self.gate.e_score_correction_bias.to(torch.bfloat16)
else:
e_score_correction_bias = None
self.gate.e_score_correction_bias = None

# Load balancing settings.
Expand Down Expand Up @@ -200,7 +265,7 @@ def __init__(
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
e_score_correction_bias=e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)

Expand All @@ -225,10 +290,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states_shared)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_SHARED_EXPERTS and self.n_shared_experts is not None:
hidden_states_shared, hidden_states_shared_scale = hidden_states_shared
shared_output_q, shared_output_s, router_logits = torch.ops.vllm.rocm_aiter_triton_fused_shared_expert(
hidden_states_shared = hidden_states_shared,
hidden_states_shared_scale = hidden_states_shared_scale,
weight_gate_up = self.shared_experts.gate_up_proj.weight,
weight_scale_gate_up = self.shared_experts.gate_up_proj.weight_scale_inv,
hidden_states_moe_gate = hidden_states,
weight_moe_gate = self.gate.weight,
bias_shared = self.shared_experts.gate_up_proj.bias if not self.shared_experts.gate_up_proj.skip_bias_add else None,
bias_moe_gate = self.gate.bias if not self.gate.skip_bias_add else None,
)
shared_output, _ = self.shared_experts.down_proj(shared_output_q, x_quant_scales = shared_output_s)
else:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states_shared)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

if VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD and hidden_states.dtype != torch.float16 and shared_output is not None:
final_hidden_states = self.experts(hidden_states=hidden_states,
Expand Down
Loading