diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index d59b74782be2..07c2b1f2cdef 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -91,6 +91,8 @@ def _rocm_aiter_fused_moe_impl(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
+    moe_sorting_dispatch_policy: int = 0,
+    moe_buf: torch.Tensor | None = None,
 ) -> torch.Tensor:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
@@ -118,6 +120,8 @@ def _rocm_aiter_fused_moe_impl(
         intermediate_pad=intermediate_pad,
         bias1=bias1,
         bias2=bias2,
+        moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,
+        moe_buf=moe_buf,
     )
 
 
@@ -141,7 +145,11 @@ def _rocm_aiter_fused_moe_fake(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
+    moe_sorting_dispatch_policy: int = 0,
+    moe_buf: torch.Tensor | None = None,
 ) -> torch.Tensor:
+    if moe_buf is not None:
+        return torch.empty_like(moe_buf)
     if output_dtype is not None:
         return torch.empty_like(hidden_states, dtype=output_dtype)
     return torch.empty_like(hidden_states)
@@ -1256,7 +1264,7 @@ def register_ops_once() -> None:
     direct_register_custom_op(
         op_name="rocm_aiter_fused_moe",
         op_func=_rocm_aiter_fused_moe_impl,
-        mutates_args=[],
+        mutates_args=["moe_buf"],
         fake_impl=_rocm_aiter_fused_moe_fake,
         dispatch_key=current_platform.dispatch_key,
     )
@@ -1546,6 +1554,8 @@ def fused_moe(
         intermediate_pad: int = 0,
         bias1: torch.Tensor | None = None,
         bias2: torch.Tensor | None = None,
+        moe_sorting_dispatch_policy: int = 0,
+        moe_buf: torch.Tensor | None = None,
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_fused_moe(
             hidden_states,
@@ -1567,6 +1577,8 @@ def fused_moe(
             intermediate_pad,
             bias1,
             bias2,
+            moe_sorting_dispatch_policy,
+            moe_buf,
         )
 
     @staticmethod
diff --git a/vllm/envs.py b/vllm/envs.py
index d2af9e64d66c..5a97a1ed660e 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -122,6 +122,7 @@
     VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
     VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
+    VLLM_ROCM_AITER_MOE_DISPATCH_POLICY: int = 0
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -1039,6 +1040,9 @@ def _get_or_set_default() -> str:
     "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
         os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")
     ),
+    "VLLM_ROCM_AITER_MOE_DISPATCH_POLICY": lambda: int(
+        os.getenv("VLLM_ROCM_AITER_MOE_DISPATCH_POLICY", "0")
+    ),
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
         os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
     ),
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index d24bda101ffa..62a7661b13a2 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -5,6 +5,7 @@
 import torch
 
+import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -195,6 +196,8 @@ def rocm_aiter_fused_experts(
     a1q_scale: torch.Tensor | None = None,
     num_local_tokens: torch.Tensor | None = None,
     output_dtype: torch.dtype | None = None,
+    moe_sorting_dispatch_policy: int = 0,
+    moe_buf: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """ROCm AITER fused MoE expert computation."""
     if quant_config is None:
@@ -309,6 +312,8 @@ def rocm_aiter_fused_experts(
         intermediate_pad=intermediate_pad,
         bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
         bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
+        moe_sorting_dispatch_policy=moe_sorting_dispatch_policy,
+        moe_buf=moe_buf,
     )
 
 
@@ -422,7 +427,7 @@ def apply(
         else:
             num_local_tokens = None
 
-        result = rocm_aiter_fused_experts(
+        rocm_aiter_fused_experts(
             hidden_states=hidden_states,
             w1=w1,
             w2=w2,
@@ -436,5 +441,6 @@ def apply(
             a1q_scale=a1q_scale,
             num_local_tokens=num_local_tokens,
             output_dtype=output.dtype,
+            moe_sorting_dispatch_policy=envs.VLLM_ROCM_AITER_MOE_DISPATCH_POLICY,
+            moe_buf=output,
         )
-        output.copy_(result)
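
The net effect of the patch: the op now writes its result straight into a caller-supplied `moe_buf`, declared in `mutates_args` so that torch.compile's functionalization knows the buffer is mutated in place, which lets `apply` pass `output` directly and drop the trailing `output.copy_(result)`, saving one output-sized device copy per call. Below is a minimal standalone sketch of the same out-buffer pattern, using plain `torch.library.custom_op` rather than vLLM's `direct_register_custom_op` wrapper; the `toy::scale_into` op and its arguments are hypothetical, not part of vLLM or AITER:

```python
import torch
from torch.library import custom_op


# Listing "out_buf" in mutates_args plays the role that
# mutates_args=["moe_buf"] plays for the real op: it tells the
# dispatcher the tensor is written in place, so torch.compile
# will not treat the op as pure and cache away the write.
@custom_op("toy::scale_into", mutates_args=["out_buf"])
def scale_into(x: torch.Tensor, scale: float, out_buf: torch.Tensor) -> None:
    # Write the result into the caller's buffer: no fresh allocation
    # and no trailing copy, the same effect as passing moe_buf=output.
    torch.mul(x, scale, out=out_buf)


@scale_into.register_fake
def _(x: torch.Tensor, scale: float, out_buf: torch.Tensor) -> None:
    # Fake/meta function: nothing to compute and nothing returned here.
    # (The real _rocm_aiter_fused_moe_fake still returns a tensor, shaped
    # like moe_buf when one is supplied.)
    return None


out = torch.empty(4)
scale_into(torch.ones(4), 2.0, out)
print(out)  # tensor([2., 2., 2., 2.])
```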