14 changes: 13 additions & 1 deletion vllm/_aiter_ops.py
@@ -91,6 +91,8 @@ def _rocm_aiter_fused_moe_impl(
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
moe_sorting_dispatch_policy: int = 0,
moe_buf: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -118,6 +120,8 @@ def _rocm_aiter_fused_moe_impl(
intermediate_pad=intermediate_pad,
bias1=bias1,
bias2=bias2,
moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,
moe_buf=moe_buf,
)


@@ -141,7 +145,11 @@ def _rocm_aiter_fused_moe_fake(
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
moe_sorting_dispatch_policy: int = 0,
moe_buf: torch.Tensor | None = None,
) -> torch.Tensor:
if moe_buf is not None:
return torch.empty_like(moe_buf)
if output_dtype is not None:
return torch.empty_like(hidden_states, dtype=output_dtype)
return torch.empty_like(hidden_states)
@@ -1256,7 +1264,7 @@ def register_ops_once() -> None:
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=_rocm_aiter_fused_moe_impl,
mutates_args=[],
mutates_args=["moe_buf"],
fake_impl=_rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
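
The `mutates_args=["moe_buf"]` registration and the `moe_buf`-aware fake implementation go together: listing the buffer as a mutated argument tells the dispatcher and torch.compile that the op writes into it in place, while the fake implementation only has to report a tensor with the right shape and dtype for tracing. Below is a minimal sketch of the same pattern, using `torch.library.custom_op` directly rather than vLLM's `direct_register_custom_op` helper; the `demo::scale_into` op is hypothetical and not part of this PR.

```python
import torch
from torch.library import custom_op, register_fake

# Hypothetical out-buffer op (illustrative only, not part of this PR).
# mutates_args declares that `out` is written in place, like moe_buf above.
@custom_op("demo::scale_into", mutates_args=("out",))
def scale_into(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    out.copy_(x * scale)

# The fake (meta) implementation never touches real data; it only has to be
# consistent with the real op's signature and return type for tracing.
@register_fake("demo::scale_into")
def _(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    return None

buf = torch.empty(4)
scale_into(torch.ones(4), 2.0, out=buf)
print(buf)  # tensor([2., 2., 2., 2.])
```

In this sketch the real implementation mutates `out` and returns nothing; in the PR the op both mutates `moe_buf` and returns a tensor, which is why `_rocm_aiter_fused_moe_fake` now returns `torch.empty_like(moe_buf)` when the buffer is supplied.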
@@ -1546,6 +1554,8 @@ def fused_moe(
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
moe_sorting_dispatch_policy: int = 0,
moe_buf: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
@@ -1567,6 +1577,8 @@ def fused_moe(
intermediate_pad,
bias1,
bias2,
moe_sorting_dispatch_policy,
moe_buf,
)

@staticmethod
4 changes: 4 additions & 0 deletions vllm/envs.py
@@ -122,6 +122,7 @@
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
VLLM_ROCM_AITER_MOE_DISPATCH_POLICY: int = 0
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -1039,6 +1040,9 @@ def _get_or_set_default() -> str:
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")
),
"VLLM_ROCM_AITER_MOE_DISPATCH_POLICY": lambda: int(
os.getenv("VLLM_ROCM_AITER_MOE_DISPATCH_POLICY", "0")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
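
One note on the envs.py addition: the lambda parses the variable with `int(os.getenv(..., "0"))`, so an unset variable yields the default policy `0` and a non-integer value raises `ValueError` the first time the setting is read. A small self-contained sketch of that parsing behavior; `read_policy` is just an illustrative stand-in for the registered lambda, and the set of meaningful policy values comes from the installed aiter release.

```python
import os

def read_policy() -> int:
    # Mirrors the lambda registered for VLLM_ROCM_AITER_MOE_DISPATCH_POLICY:
    # unset -> 0, otherwise the string must parse as an integer.
    return int(os.getenv("VLLM_ROCM_AITER_MOE_DISPATCH_POLICY", "0"))

print(read_policy())  # 0 (default)
os.environ["VLLM_ROCM_AITER_MOE_DISPATCH_POLICY"] = "1"
print(read_policy())  # 1
```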
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -5,6 +5,7 @@

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -195,6 +196,8 @@ def rocm_aiter_fused_experts(
a1q_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
moe_sorting_dispatch_policy: int = 0,
moe_buf: torch.Tensor | None = None,
) -> torch.Tensor:
"""ROCm AITER fused MoE expert computation."""
if quant_config is None:
@@ -309,6 +312,8 @@
intermediate_pad=intermediate_pad,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,
moe_buf=moe_buf,
)


@@ -422,7 +427,7 @@ def apply(
else:
num_local_tokens = None

result = rocm_aiter_fused_experts(
rocm_aiter_fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
@@ -436,5 +441,6 @@ def apply(
a1q_scale=a1q_scale,
num_local_tokens=num_local_tokens,
output_dtype=output.dtype,
moe_sorting_dispatch_policy=envs.VLLM_ROCM_AITER_MOE_DISPATCH_POLICY,
moe_buf=output,
)
output.copy_(result)
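
The `apply` change above is the caller-side half of the out-buffer plumbing: instead of materializing a separate `result` tensor and then copying it into `output`, the preallocated `output` is handed to the op as `moe_buf` so the kernel can write into it directly. A schematic sketch of the before/after calling pattern follows, with a stand-in `kernel` function rather than the real aiter op, which is assumed to honor `moe_buf` internally.

```python
import torch

def kernel(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    # Stand-in for the fused MoE op: write into the caller's buffer when one
    # is provided, otherwise allocate and return a fresh result tensor.
    y = x * 2.0
    return y if out is None else out.copy_(y)

x = torch.randn(8, 16)
output = torch.empty_like(x)

# Before: compute into a temporary result, then copy into `output`.
output.copy_(kernel(x))

# After: pass the preallocated buffer (moe_buf=output in the diff), so the
# intermediate result tensor and the trailing copy_ are avoided.
kernel(x, out=output)
```

This is also why the op registration had to gain `mutates_args=["moe_buf"]`: once the op writes into the caller's buffer it is no longer purely functional, and the compiler needs to know that.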