diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 633e23eea33e..ee224e6922fb 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
-| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
+| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched4 | batched | int8,fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 014df1fa111f..c27cf2468ede 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -6,6 +6,8 @@
"""
import functools
+import importlib
+import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@@ -20,6 +22,7 @@
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
@@ -412,14 +415,12 @@ def test_mixtral_moe(
huggingface."""
# clear the cache before every test
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- is_rocm_aiter_moe_enabled,
- )
+    # Force-reload vllm._aiter_ops so it picks up the new environment variables.
+    if "vllm._aiter_ops" in sys.modules:
+        importlib.reload(sys.modules["vllm._aiter_ops"])
- is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
-
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 41419553aa83..9121284de85b 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -4,6 +4,7 @@
import pytest
import torch
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
@@ -15,9 +16,6 @@
dispatch_topk_func,
vllm_topk_softmax,
)
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- is_rocm_aiter_moe_enabled,
-)
from vllm.model_executor.layers.layernorm import (
RMSNorm,
dispatch_rocm_rmsnorm_func,
@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()
-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
- monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
- topk_func = dispatch_topk_func()
- is_rocm_aiter_moe_enabled.cache_clear()
- if current_platform.is_rocm() and int(use_rocm_aiter):
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- rocm_aiter_topk_softmax,
- )
+@pytest.mark.parametrize(
+ "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
+)
+def test_topk_dispatch(use_rocm_aiter: bool):
+ topk_func = dispatch_topk_func(use_rocm_aiter)
- assert topk_func == rocm_aiter_topk_softmax
+ if current_platform.is_rocm() and use_rocm_aiter:
+ assert topk_func == rocm_aiter_ops.topk_softmax
else:
assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
)
def test_rms_norm_dispatch(
- add_residual: bool,
- dtype: torch.dtype,
- use_rocm_aiter: str,
- use_rocm_aiter_norm: str,
- monkeypatch,
+ add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
):
- monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
- monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
- rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
+ rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
should_use_rocm_aiter = (
current_platform.is_rocm()
- and int(use_rocm_aiter)
- and int(use_rocm_aiter_norm)
+ and use_rocm_aiter
and dtype in RMS_NORM_SUPPORTED_DTYPES
)
if add_residual and should_use_rocm_aiter:
- assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+ assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
- assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
+ assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
new file mode 100644
index 000000000000..9a4b5f3399be
--- /dev/null
+++ b/vllm/_aiter_ops.py
@@ -0,0 +1,941 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from collections.abc import Callable
+
+import torch
+
+import vllm.envs as envs
+from vllm.platforms import current_platform
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
+
+
+def is_aiter_found() -> bool:
+ from importlib.util import find_spec
+
+ return find_spec("aiter") is not None
+
+
+# `find_spec` is not torch.compile compatible, and AITER availability may be
+# checked inside forward passes that are torch compiled, so we cache the result
+# in a module-level constant to avoid graph breaks.
+IS_AITER_FOUND = is_aiter_found()
+
+
+def if_aiter_supported(func: Callable) -> Callable:
+ """Decorator that only executes the function if
+ ROCm AITER package is supported on gfx9 archs.
+ """
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+        # Check the platform, device arch, and aiter library existence.
+
+ from vllm.platforms.rocm import on_gfx9
+
+ if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
+ return func(*args, **kwargs)
+ else:
+ # Return None or do nothing if not supported
+ return None
+
+ return wrapper
+
+
+def _rocm_aiter_fused_moe_impl(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weight: torch.Tensor,
+ topk_ids: torch.Tensor,
+ expert_mask: torch.Tensor | None = None,
+ activation_method: int = 0,
+ quant_method: int = 0,
+ doweight_stage1: bool = False,
+ w1_scale: torch.Tensor | None = None,
+ w2_scale: torch.Tensor | None = None,
+ a1_scale: torch.Tensor | None = None,
+ a2_scale: torch.Tensor | None = None,
+) -> torch.Tensor:
+ from aiter import ActivationType, QuantType
+ from aiter.fused_moe import fused_moe
+
+ activation = ActivationType(activation_method)
+ quant_type = QuantType(quant_method)
+
+ return fused_moe(
+ hidden_states,
+ w1,
+ w2,
+ topk_weight,
+ topk_ids,
+ expert_mask,
+ activation,
+ quant_type,
+ doweight_stage1,
+ w1_scale,
+ w2_scale,
+ a1_scale,
+ a2_scale,
+ )
+
+
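+# Each *_fake function below only produces outputs with the correct shape and
+# dtype; it backs the corresponding custom op during torch.compile / fake-tensor
+# tracing and never calls into AITER kernels.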
+def _rocm_aiter_fused_moe_fake(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weight: torch.Tensor,
+ topk_ids: torch.Tensor,
+ expert_mask: torch.Tensor | None = None,
+ activation_method: int = 0,
+ quant_method: int = 0,
+ doweight_stage1: bool = False,
+ w1_scale: torch.Tensor | None = None,
+ w2_scale: torch.Tensor | None = None,
+ a1_scale: torch.Tensor | None = None,
+ a2_scale: torch.Tensor | None = None,
+) -> torch.Tensor:
+ return torch.empty_like(hidden_states)
+
+
+def _rocm_aiter_asm_moe_tkw1_impl(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ fc1_scale: torch.Tensor | None = None,
+ fc2_scale: torch.Tensor | None = None,
+ fc1_smooth_scale: torch.Tensor | None = None,
+ fc2_smooth_scale: torch.Tensor | None = None,
+ a16: bool = False,
+ per_tensor_quant_scale: torch.Tensor | None = None,
+ expert_mask: torch.Tensor | None = None,
+ activation_method: int = 0,
+) -> torch.Tensor:
+ from aiter import ActivationType
+ from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+ activation = ActivationType(activation_method)
+
+ return asm_moe_tkw1(
+ hidden_states,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ fc1_scale=fc1_scale,
+ fc2_scale=fc2_scale,
+ fc1_smooth_scale=fc1_smooth_scale,
+ fc2_smooth_scale=fc2_smooth_scale,
+ a16=a16,
+ per_tensor_quant_scale=per_tensor_quant_scale,
+ expert_mask=expert_mask,
+ activation=activation,
+ )
+
+
+def _rocm_aiter_asm_moe_tkw1_fake(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ fc1_scale: torch.Tensor | None = None,
+ fc2_scale: torch.Tensor | None = None,
+ fc1_smooth_scale: torch.Tensor | None = None,
+ fc2_smooth_scale: torch.Tensor | None = None,
+ a16: bool = False,
+ per_tensor_quant_scale: torch.Tensor | None = None,
+ expert_mask: torch.Tensor | None = None,
+ activation_method: int = 0,
+) -> torch.Tensor:
+ return torch.empty_like(hidden_states)
+
+
+def _rocm_aiter_topk_softmax_impl(
+ topk_weights: torch.Tensor,
+ topk_indices: torch.Tensor,
+ token_expert_indices: torch.Tensor,
+ gating_output: torch.Tensor,
+ renormalize: bool,
+) -> None:
+ from aiter import topk_softmax
+
+ topk_softmax(
+ topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
+ )
+
+
+def _rocm_aiter_topk_softmax_fake(
+ topk_weights: torch.Tensor,
+ topk_indices: torch.Tensor,
+ token_expert_indices: torch.Tensor,
+ gating_output: torch.Tensor,
+ renormalize: bool,
+) -> None:
+ pass
+
+
+def _rocm_aiter_biased_grouped_topk_impl(
+ gating_output: torch.Tensor,
+ correction_bias: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_expert_group: int,
+ topk_group: int,
+ need_renorm: bool,
+ routed_scaling_factor: float = 1.0, # mul to topk_weights
+) -> None:
+ from aiter import biased_grouped_topk
+
+ biased_grouped_topk(
+ gating_output,
+ correction_bias,
+ topk_weights,
+ topk_ids,
+ num_expert_group,
+ topk_group,
+ need_renorm,
+ routed_scaling_factor,
+ )
+
+
+def _rocm_aiter_biased_grouped_topk_fake(
+ gating_output: torch.Tensor,
+ correction_bias: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_expert_group: int,
+ topk_group: int,
+ need_renorm: bool,
+ routed_scaling_factor: float = 1.0, # mul to topk_weights
+) -> None:
+ pass
+
+
+def _rocm_aiter_grouped_topk_impl(
+ gating_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_expert_group: int,
+ topk_group: int,
+ need_renorm: bool,
+ scoring_func: str = "softmax",
+ routed_scaling_factor: float = 1.0, # mul to topk_weights
+) -> None:
+ is_softmax = scoring_func == "softmax"
+ from aiter import grouped_topk
+
+ grouped_topk(
+ gating_output,
+ topk_weights,
+ topk_ids,
+ num_expert_group,
+ topk_group,
+ need_renorm,
+ is_softmax,
+ routed_scaling_factor,
+ )
+
+
+def _rocm_aiter_grouped_topk_fake(
+ gating_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_expert_group: int,
+ topk_group: int,
+ need_renorm: bool,
+ scoring_func: str = "softmax",
+ routed_scaling_factor: float = 1.0, # mul to topk_weights
+) -> None:
+ pass
+
+
+def _rocm_aiter_mla_decode_fwd_impl(
+ q: torch.Tensor,
+ kv_buffer: torch.Tensor,
+ o: torch.Tensor,
+ qo_indptr: torch.Tensor,
+ max_seqlen_qo: int,
+ kv_indptr: torch.Tensor | None = None,
+ kv_indices: torch.Tensor | None = None,
+ kv_last_page_lens: torch.Tensor | None = None,
+ sm_scale: float = 1.0,
+ logit_cap: float = 0.0,
+) -> None:
+ from aiter.mla import mla_decode_fwd
+
+ mla_decode_fwd(
+ q,
+ kv_buffer.view(-1, 1, 1, q.shape[-1]),
+ o,
+ qo_indptr,
+ kv_indptr,
+ kv_indices,
+ kv_last_page_lens,
+ max_seqlen_qo,
+ sm_scale=sm_scale,
+ logit_cap=logit_cap,
+ )
+
+
+def _rocm_aiter_mla_decode_fwd_fake(
+ q: torch.Tensor,
+ kv_buffer: torch.Tensor,
+ o: torch.Tensor,
+ qo_indptr: torch.Tensor,
+ max_seqlen_qo: int,
+ kv_indptr: torch.Tensor | None = None,
+ kv_indices: torch.Tensor | None = None,
+ kv_last_page_lens: torch.Tensor | None = None,
+ sm_scale: float = 1.0,
+ logit_cap: float = 0.0,
+) -> None:
+ pass
+
+
+def _rocm_aiter_gemm_w8a8_impl(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+ from aiter import gemm_a8w8_CK
+
+ # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+ # a to be [M, K]
+ # b to be [N, K]
+    # CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
+ return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
+
+
+def _rocm_aiter_gemm_w8a8_fake(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+ m = A.shape[0]
+ n = B.shape[0]
+ Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+ return Y
+
+
+def _rocm_aiter_gemm_w8a8_blockscale_impl(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+ from aiter import gemm_a8w8_blockscale
+
+ return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+
+
+def _rocm_aiter_gemm_w8a8_blockscale_fake(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+ m = A.shape[0]
+ n = B.shape[0]
+ Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+ return Y
+
+
+def _rocm_aiter_rms_norm_impl(
+ x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+) -> torch.Tensor:
+ from aiter import rms_norm
+
+ if x.dim() > 2:
+ x_original_shape = x.shape
+ x = x.reshape(-1, x_original_shape[-1])
+ x = rms_norm(x, weight, variance_epsilon)
+ return x.reshape(x_original_shape)
+
+ return rms_norm(x, weight, variance_epsilon)
+
+
+def _rocm_aiter_rms_norm_fake(
+ x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+) -> torch.Tensor:
+ return torch.empty_like(x)
+
+
+def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ from aiter import rmsnorm2d_fwd_with_add
+
+ residual_out = torch.empty_like(residual)
+ output = torch.empty_like(x)
+ rmsnorm2d_fwd_with_add(
+ output, # output
+ x, # input
+ residual, # residual input
+ residual_out, # residual output
+ weight,
+ variance_epsilon,
+ )
+ return output, residual_out
+
+
+def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ return torch.empty_like(x), torch.empty_like(residual)
+
+
+# Global flag to ensure ops are registered only once
+_OPS_REGISTERED = False
+
+
+class rocm_aiter_ops:
+ _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
+ _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
+ _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
+ _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
+ _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
+ _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+ _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
+ _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+ _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
+ _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
+ _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
+ _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
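+    # NOTE: these flags are captured once at import time; tests that change the
+    # VLLM_ROCM_USE_AITER* environment variables must reload `vllm._aiter_ops`
+    # for the new values to take effect.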
+
+ @classmethod
+ @if_aiter_supported
+ def is_enabled(cls) -> bool:
+ """Verifies device specs and availability of aiter main env variable."""
+ return cls._AITER_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_linear_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._LINEAR_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_linear_fp8_enaled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
+
+ @classmethod
+ @if_aiter_supported
+ def is_rmsnorm_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_fused_moe_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._FMOE_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_fusion_moe_shared_experts_enabled(cls) -> bool:
+ return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_mla_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._MLA_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_mha_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._MHA_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_pa_attn_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_triton_unified_attn_enabled(cls) -> bool:
+ """ "Verifies device specs and availability of env variable."""
+ return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_fp8bmm_enabled(cls) -> bool:
+ return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
+
+ @classmethod
+ @if_aiter_supported
+ def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
+ return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM
+
+ @classmethod
+ @if_aiter_supported
+ def is_triton_rotary_embed_enabled(cls) -> bool:
+ return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
+
+ @staticmethod
+ @if_aiter_supported
+ def register_ops_once() -> None:
+ global _OPS_REGISTERED
+ if not _OPS_REGISTERED:
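+            # Apply the needs_fixed_stride_order tag only on torch < 2.7;
+            # newer torch no longer needs it for these custom ops.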
+ tags = (
+ tuple()
+ if is_torch_equal_or_newer("2.7.0")
+ else (torch.Tag.needs_fixed_stride_order,)
+ )
+
+ # register all the custom ops here
+ direct_register_custom_op(
+ op_name="rocm_aiter_asm_moe_tkw1",
+ op_func=_rocm_aiter_asm_moe_tkw1_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_fused_moe",
+ op_func=_rocm_aiter_fused_moe_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_fused_moe_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_topk_softmax",
+ op_func=_rocm_aiter_topk_softmax_impl,
+ mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
+ fake_impl=_rocm_aiter_topk_softmax_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_biased_grouped_topk",
+ op_func=_rocm_aiter_biased_grouped_topk_impl,
+ mutates_args=["topk_weights", "topk_ids"],
+ fake_impl=_rocm_aiter_biased_grouped_topk_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_grouped_topk",
+ op_func=_rocm_aiter_grouped_topk_impl,
+ mutates_args=["topk_weights", "topk_ids"],
+ fake_impl=_rocm_aiter_grouped_topk_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_mla_decode_fwd",
+ op_func=_rocm_aiter_mla_decode_fwd_impl,
+ mutates_args=["o"],
+ fake_impl=_rocm_aiter_mla_decode_fwd_fake,
+ tags=tags,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_gemm_w8a8",
+ op_func=_rocm_aiter_gemm_w8a8_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_gemm_w8a8_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_gemm_w8a8_blockscale",
+ op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_rms_norm",
+ op_func=_rocm_aiter_rms_norm_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_rms_norm_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ direct_register_custom_op(
+ op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
+ op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
+ _OPS_REGISTERED = True
+
+ @staticmethod
+ def rms_norm2d_with_add(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ variance_epsilon: float,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
+ x, residual, weight, variance_epsilon
+ )
+
+ @staticmethod
+ def rms_norm(
+ x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+ ) -> torch.Tensor:
+ return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
+
+ @staticmethod
+ def gemm_w8a8(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+ return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
+
+ @staticmethod
+ def gemm_w8a8_blockscale(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ block_size: list[int],
+ output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+ return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
+ A, B, As, Bs, output_dtype
+ )
+
+ @staticmethod
+ def fused_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weight: torch.Tensor,
+ topk_ids: torch.Tensor,
+ expert_mask: torch.Tensor | None = None,
+ activation_method: int = 0,
+ quant_method: int = 0,
+ doweight_stage1: bool = False,
+ w1_scale: torch.Tensor | None = None,
+ w2_scale: torch.Tensor | None = None,
+ a1_scale: torch.Tensor | None = None,
+ a2_scale: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ return torch.ops.vllm.rocm_aiter_fused_moe(
+ hidden_states,
+ w1,
+ w2,
+ topk_weight,
+ topk_ids,
+ expert_mask,
+ activation_method,
+ quant_method,
+ doweight_stage1,
+ w1_scale,
+ w2_scale,
+ a1_scale,
+ a2_scale,
+ )
+
+ @staticmethod
+ def asm_moe_tkw1(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ fc1_scale: torch.Tensor | None = None,
+ fc2_scale: torch.Tensor | None = None,
+ fc1_smooth_scale: torch.Tensor | None = None,
+ fc2_smooth_scale: torch.Tensor | None = None,
+ a16: bool = False,
+ per_tensor_quant_scale: torch.Tensor | None = None,
+ expert_mask: torch.Tensor | None = None,
+ activation_method: int = 0,
+ ) -> torch.Tensor:
+ return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
+ hidden_states,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ fc1_scale,
+ fc2_scale,
+ fc1_smooth_scale,
+ fc2_smooth_scale,
+ a16,
+ per_tensor_quant_scale,
+ expert_mask,
+ activation_method,
+ )
+
+ @staticmethod
+ def topk_softmax(
+ topk_weights: torch.Tensor,
+ topk_indices: torch.Tensor,
+ token_expert_indices: torch.Tensor,
+ gating_output: torch.Tensor,
+ renormalize: bool,
+ ) -> tuple[torch.Tensor, ...]:
+ torch.ops.vllm.rocm_aiter_topk_softmax(
+ topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
+ )
+ return topk_weights, topk_indices
+
+ @staticmethod
+ def biased_grouped_topk(
+ gating_output: torch.Tensor,
+ correction_bias: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_expert_group: int,
+ topk_group: int,
+ need_renorm: bool,
+ routed_scaling_factor: float = 1.0,
+ ) -> None:
+ torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+ gating_output,
+ correction_bias,
+ topk_weights,
+ topk_ids,
+ num_expert_group,
+ topk_group,
+ need_renorm,
+ routed_scaling_factor,
+ )
+
+ @staticmethod
+ def grouped_topk(
+ gating_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_expert_group: int,
+ topk_group: int,
+ need_renorm: bool,
+ scoring_func: str = "softmax",
+ routed_scaling_factor: float = 1.0,
+ ) -> None:
+ torch.ops.vllm.rocm_aiter_grouped_topk(
+ gating_output,
+ topk_weights,
+ topk_ids,
+ num_expert_group,
+ topk_group,
+ need_renorm,
+ scoring_func,
+ routed_scaling_factor,
+ )
+
+ @staticmethod
+ def mla_decode_fwd(
+ q: torch.Tensor,
+ kv_buffer: torch.Tensor,
+ o: torch.Tensor,
+ sm_scale: float,
+ qo_indptr: torch.Tensor,
+ max_seqlen_qo: int,
+ kv_indptr: torch.Tensor | None = None,
+ kv_indices: torch.Tensor | None = None,
+ kv_last_page_lens: torch.Tensor | None = None,
+ logit_cap: float = 0.0,
+ ):
+ torch.ops.vllm.rocm_aiter_mla_decode_fwd(
+ q,
+ kv_buffer.view(-1, 1, 1, q.shape[-1]),
+ o,
+ qo_indptr,
+ max_seqlen_qo,
+ kv_indptr,
+ kv_indices,
+ kv_last_page_lens,
+ sm_scale=sm_scale,
+ logit_cap=logit_cap,
+ )
+
+ @staticmethod
+ def triton_fp4_gemm_dynamic_qaunt(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ out_dtype: torch.dtype | None = torch.bfloat16,
+ x_scales: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+ from aiter.ops.triton.quant import dynamic_mxfp4_quant
+
+ if x_scales is None:
+ x_q, x_s = dynamic_mxfp4_quant(x)
+ else:
+ x_q = x
+ x_s = x_scales
+
+ y = torch.empty(
+ x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
+ )
+
+ gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
+ return y
+
+ @staticmethod
+ def triton_rotary_embed(
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos_sin_cache: torch.Tensor,
+ head_size: int,
+ rotary_dim: int,
+ is_neox_style: bool,
+ ):
+ from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace
+
+ num_tokens = positions.numel()
+ cos, sin = cos_sin_cache.chunk(2, dim=-1)
+ query_shape = query.shape
+ key_shape = key.shape
+ rotate_style = 0 if is_neox_style else 1
+
+ query = query.view(num_tokens, -1, head_size)
+ key = key.view(num_tokens, -1, head_size)
+ query_ = query[..., :rotary_dim]
+ key_ = key[..., :rotary_dim]
+ positions = positions.view(*query.shape[:1])
+ rope_cached_thd_positions_2c_fwd_inplace(
+ positions,
+ sin,
+ cos,
+ query_,
+ key_,
+ rotate_style,
+ reuse_freqs_front_part=True,
+ is_nope_first=False,
+ )
+ query = query.view(query_shape)
+ key = key.view(key_shape)
+
+ @staticmethod
+ def triton_fp8_bmm(
+ X: torch.Tensor,
+ WQ: torch.Tensor,
+ w_scale: torch.Tensor,
+ group_size: int = 128,
+ bias: torch.Tensor | None = None,
+ dtype: torch.dtype | None = torch.bfloat16,
+ splitK: int | None = None,
+ YQ: torch.Tensor | None = None,
+ transpose_bm: bool | None = False,
+ config: dict | None = None,
+ ) -> torch.Tensor:
+ # ruff: noqa: E501 # isort: skip
+ from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
+ batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
+ )
+
+ return aiter_triton_fp8_bmm(
+ X,
+ WQ,
+ w_scale,
+ group_size=group_size,
+ bias=bias,
+ dtype=dtype,
+ splitK=splitK,
+ YQ=YQ,
+ transpose_bm=transpose_bm,
+ config=config,
+ )
+
+ @staticmethod
+ def triton_gemm_a8w8_blockscale(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ block_size: list[int],
+ output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+ from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+ return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+
+ @staticmethod
+ def per_1x128_fp8_quant(
+ input_2d: torch.Tensor,
+ ) -> tuple[torch.Tensor, ...]:
+ """Only applies quantization method for fp8 data type only."""
+ from aiter import QuantType, dtypes, get_hip_quant
+
+ aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
+ return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
+
+ @staticmethod
+ def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
+ return (n, k) in [
+ (1024, 8192),
+ (2112, 7168),
+ (3072, 1536),
+ (32768, 8192),
+ (4096, 7168),
+ (4608, 7168),
+ (512, 7168),
+ (7168, 2048),
+ (7168, 256),
+ (8192, 1024),
+ (8192, 32768),
+ ]
+
+ @staticmethod
+    def shuffle_weight(
+        tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
+    ) -> torch.Tensor:
+ from aiter.ops.shuffle import shuffle_weight
+
+ return shuffle_weight(tensor, layout=layout)
+
+ @staticmethod
+ def shuffle_weights(
+ *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
+ ) -> tuple[torch.Tensor, ...]:
+ """
+        Applies AITER's shuffle_weight function to each input tensor
+        and returns the shuffled tensors.
+
+        Rearranges (shuffles) the input tensor(s) into the specified
+        block layout for optimized computation.
+
+ Args:
+ *tensors: Variable number of torch.Tensor objects.
+ layout: A pair of integers specifying the block sizes used to divide
+ the tensors during shuffling. Default is (16, 16).
+
+ Returns:
+ A Tuple of shuffled tensors.
+ """
+ from aiter.ops.shuffle import shuffle_weight
+
+ return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
+
+
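+# Register the custom ops at import time. On unsupported platforms the
+# `@if_aiter_supported` decorator turns this call into a no-op.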
+rocm_aiter_ops.register_ops_once()
diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
deleted file mode 100644
index 6308f63cc4e7..000000000000
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-
-import torch
-
-from vllm.platforms import current_platform
-from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
-
-
-def get_aiter_mla_metadata(
- max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
-) -> tuple[torch.Tensor, ...]:
- paged_kv_indices = torch.zeros(
- max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
- )
- paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
- paged_kv_last_page_lens = torch.full(
- (max_batch_size,), block_size, dtype=torch.int32
- )
- qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
- return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
-
-
-def aiter_mla_decode_fwd(
- q: torch.Tensor,
- kv_buffer: torch.Tensor,
- o: torch.Tensor,
- sm_scale: float,
- qo_indptr: torch.Tensor,
- max_seqlen_qo: int,
- kv_indptr: torch.Tensor | None = None,
- kv_indices: torch.Tensor | None = None,
- kv_last_page_lens: torch.Tensor | None = None,
- logit_cap: float = 0.0,
-):
- torch.ops.vllm.rocm_aiter_mla_decode_fwd(
- q,
- kv_buffer.view(-1, 1, 1, q.shape[-1]),
- o,
- qo_indptr,
- max_seqlen_qo,
- kv_indptr,
- kv_indices,
- kv_last_page_lens,
- sm_scale=sm_scale,
- logit_cap=logit_cap,
- )
-
-
-def mla_decode_fwd_impl(
- q: torch.Tensor,
- kv_buffer: torch.Tensor,
- o: torch.Tensor,
- qo_indptr: torch.Tensor,
- max_seqlen_qo: int,
- kv_indptr: torch.Tensor | None = None,
- kv_indices: torch.Tensor | None = None,
- kv_last_page_lens: torch.Tensor | None = None,
- sm_scale: float = 1.0,
- logit_cap: float = 0.0,
-) -> None:
- from aiter.mla import mla_decode_fwd
-
- mla_decode_fwd(
- q,
- kv_buffer.view(-1, 1, 1, q.shape[-1]),
- o,
- qo_indptr,
- kv_indptr,
- kv_indices,
- kv_last_page_lens,
- max_seqlen_qo,
- sm_scale=sm_scale,
- logit_cap=logit_cap,
- )
-
-
-def mla_decode_fwd_fake(
- q: torch.Tensor,
- kv_buffer: torch.Tensor,
- o: torch.Tensor,
- qo_indptr: torch.Tensor,
- max_seqlen_qo: int,
- kv_indptr: torch.Tensor | None = None,
- kv_indices: torch.Tensor | None = None,
- kv_last_page_lens: torch.Tensor | None = None,
- sm_scale: float = 1.0,
- logit_cap: float = 0.0,
-) -> None:
- pass
-
-
-if current_platform.is_rocm():
- if is_torch_equal_or_newer("2.7.0"):
- tags = ()
- else:
- tags = ((torch.Tag.needs_fixed_stride_order,),)
- direct_register_custom_op(
- op_name="rocm_aiter_mla_decode_fwd",
- op_func=mla_decode_fwd_impl,
- mutates_args=["o"],
- fake_impl=mla_decode_fwd_fake,
- tags=tags,
- )
diff --git a/vllm/envs.py b/vllm/envs.py
index 078e5c38f0f4..30c62e90e9fb 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -109,7 +109,7 @@
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
- VLLM_ROCM_USE_TRITON_ROPE: bool = False
+ VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
@@ -926,8 +926,8 @@ def get_vllm_port() -> int | None:
),
# Whether to use aiter rope.
# By default is disabled.
- "VLLM_ROCM_USE_TRITON_ROPE": lambda: (
- os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
+ "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
+ os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
@@ -1589,7 +1589,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
- "VLLM_ROCM_USE_TRITON_ROPE",
+ "VLLM_ROCM_USE_AITER_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_AITER_TRITON_GEMM",
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 7ad3ce1397b3..2e042d85fcfc 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -14,6 +14,7 @@
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@@ -55,8 +56,6 @@
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
-from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
-
logger = init_logger(__name__)
@@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
return topk_weights, topk_indices
-def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
- if is_rocm_aiter_moe_enabled():
- from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
-
- return rocm_aiter_topk_softmax
+def dispatch_topk_func(
+ use_rocm_aiter: bool = False,
+) -> Callable[..., tuple[torch.Tensor, ...]]:
+ if use_rocm_aiter:
+ return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
@@ -1121,7 +1120,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
- topk_func = dispatch_topk_func()
+ topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index e69ead074c50..45b0f50a7997 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -13,6 +13,7 @@
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
@@ -41,8 +42,6 @@
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
- is_rocm_aiter_fusion_shared_expert_enabled,
- is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
@@ -92,13 +91,11 @@ def _eplb_map_to_physical_and_record(
return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
+from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
+ rocm_aiter_grouped_topk,
+)
-if is_rocm_aiter_moe_enabled():
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
- rocm_aiter_grouped_topk as grouped_topk_aiter,
- )
-else:
- from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
@@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
- self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
+ self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@@ -620,13 +618,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
- # Lazy import to avoid importing triton.
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- shuffle_weights,
- )
if self.rocm_aiter_moe_enabled:
- shuffled_w13, shuffled_w2 = shuffle_weights(
+ shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@@ -1002,6 +996,7 @@ def determine_expert_map(
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
num_fused_shared_experts: int = 0,
+ return_expert_mask: bool = False,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
@@ -1064,7 +1059,7 @@ def determine_expert_map(
)
expert_mask = None
- if is_rocm_aiter_moe_enabled():
+ if return_expert_mask:
expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
)
@@ -1292,14 +1287,18 @@ def __init__(
self.logical_replica_count: torch.Tensor | None = None
# ROCm aiter shared experts fusion
+ self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
+ self.aiter_fmoe_shared_expert_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
+
self.num_fused_shared_experts = (
n_shared_experts
- if n_shared_experts is not None
- and is_rocm_aiter_fusion_shared_expert_enabled()
+ if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
else 0
)
if (
- not is_rocm_aiter_fusion_shared_expert_enabled()
+ not self.aiter_fmoe_shared_expert_enabled
and self.num_fused_shared_experts != 0
):
raise ValueError(
@@ -1346,6 +1345,7 @@ def __init__(
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
+ return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
@@ -1570,13 +1570,16 @@ def update_expert_map(self):
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
+ return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
- self._init_aiter_shared_experts_topK_buffer(
- vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
- )
+ if self.aiter_fmoe_shared_expert_enabled:
+ self._init_aiter_shared_experts_topK_buffer(
+ vllm_config=get_current_vllm_config(),
+ dp_size=get_dp_group().world_size,
+ )
def _load_per_tensor_weight_scale(
self,
@@ -1753,20 +1756,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
):
- if is_rocm_aiter_fusion_shared_expert_enabled():
- if self.num_fused_shared_experts > 0:
- init_aiter_topK_meta_data(
- n_routed_experts=self.global_num_experts,
- n_shared_experts=self.num_fused_shared_experts,
- top_k=self.top_k,
- tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
- tp_size=self.ep_size if self.use_ep else self.tp_size,
- shared_experts_score=1.0,
- max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
- * dp_size,
- is_EP=self.use_ep,
- )
- self.local_num_experts += self.num_fused_shared_experts
+ if self.num_fused_shared_experts > 0:
+ init_aiter_topK_meta_data(
+ n_routed_experts=self.global_num_experts,
+ n_shared_experts=self.num_fused_shared_experts,
+ top_k=self.top_k,
+ tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
+ tp_size=self.ep_size if self.use_ep else self.tp_size,
+ shared_experts_score=1.0,
+ max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
+ * dp_size,
+ is_EP=self.use_ep,
+ )
+ self.local_num_experts += self.num_fused_shared_experts
@overload
def weight_loader(
@@ -2208,15 +2210,16 @@ def select_experts(
elif use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
- if is_rocm_aiter_moe_enabled():
- if not is_rocm_aiter_fusion_shared_expert_enabled():
+ if rocm_aiter_ops.is_fused_moe_enabled():
+ if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0
grouped_topk_impl = partial(
- grouped_topk_aiter,
+ rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
+
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
@@ -2448,7 +2451,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
- if not is_rocm_aiter_moe_enabled()
+ if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
@@ -2612,7 +2615,7 @@ def forward_impl(
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
- if not is_rocm_aiter_moe_enabled()
+ if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index e18514ad43f6..8f05828d74f5 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -1,17 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
-from functools import cache, lru_cache
+from functools import lru_cache
import torch
-from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
)
-from vllm.platforms import current_platform
-from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum):
@@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
GELU = 1
-@cache
-def is_rocm_aiter_moe_enabled() -> bool:
- return (
- current_platform.is_rocm()
- and envs.VLLM_ROCM_USE_AITER_MOE
- and envs.VLLM_ROCM_USE_AITER
- )
-
-
-@cache
-def use_mxfp4_aiter_moe() -> bool:
- return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
-
-
-@cache
-def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
- return (
- envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
- )
-
-
aiter_topK_meta_data = None
@@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
-def rocm_aiter_asm_moe_tkw1_impl(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- fc1_scale: torch.Tensor | None = None,
- fc2_scale: torch.Tensor | None = None,
- fc1_smooth_scale: torch.Tensor | None = None,
- fc2_smooth_scale: torch.Tensor | None = None,
- a16: bool = False,
- per_tensor_quant_scale: torch.Tensor | None = None,
- expert_mask: torch.Tensor | None = None,
- activation_method: int = ActivationMethod.SILU.value,
-) -> torch.Tensor:
- from aiter import ActivationType
- from aiter.fused_moe_bf16_asm import asm_moe_tkw1
-
- activation = ActivationType(activation_method)
-
- return asm_moe_tkw1(
- hidden_states,
- w1,
- w2,
- topk_weights,
- topk_ids,
- fc1_scale=fc1_scale,
- fc2_scale=fc2_scale,
- fc1_smooth_scale=fc1_smooth_scale,
- fc2_smooth_scale=fc2_smooth_scale,
- a16=a16,
- per_tensor_quant_scale=per_tensor_quant_scale,
- expert_mask=expert_mask,
- activation=activation,
- )
-
-
-def rocm_aiter_asm_moe_tkw1_fake(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- fc1_scale: torch.Tensor | None = None,
- fc2_scale: torch.Tensor | None = None,
- fc1_smooth_scale: torch.Tensor | None = None,
- fc2_smooth_scale: torch.Tensor | None = None,
- a16: bool = False,
- per_tensor_quant_scale: torch.Tensor | None = None,
- expert_mask: torch.Tensor | None = None,
- activation_method: int = ActivationMethod.SILU.value,
-) -> torch.Tensor:
- return torch.empty_like(hidden_states)
-
-
-def rocm_aiter_topk_softmax_impl(
- topk_weights: torch.Tensor,
- topk_indices: torch.Tensor,
- token_expert_indices: torch.Tensor,
- gating_output: torch.Tensor,
- renormalize: bool,
-) -> None:
- from aiter import topk_softmax
-
- topk_softmax(
- topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
- )
-
-
-def rocm_aiter_topk_softmax_fake(
- topk_weights: torch.Tensor,
- topk_indices: torch.Tensor,
- token_expert_indices: torch.Tensor,
- gating_output: torch.Tensor,
- renormalize: bool,
-) -> None:
- pass
-
-
-def rocm_aiter_biased_grouped_topk_impl(
- gating_output: torch.Tensor,
- correction_bias: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_expert_group: int,
- topk_group: int,
- need_renorm: bool,
- routed_scaling_factor: float = 1.0, # mul to topk_weights
-) -> None:
- from aiter import biased_grouped_topk
-
- biased_grouped_topk(
- gating_output,
- correction_bias,
- topk_weights,
- topk_ids,
- num_expert_group,
- topk_group,
- need_renorm,
- routed_scaling_factor,
- )
-
-
-def rocm_aiter_biased_grouped_topk_fake(
- gating_output: torch.Tensor,
- correction_bias: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_expert_group: int,
- topk_group: int,
- need_renorm: bool,
- routed_scaling_factor: float = 1.0, # mul to topk_weights
-) -> None:
- pass
-
-
-def rocm_aiter_grouped_topk_impl(
- gating_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_expert_group: int,
- topk_group: int,
- need_renorm: bool,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0, # mul to topk_weights
-) -> None:
- from aiter import grouped_topk
-
- grouped_topk(
- gating_output,
- topk_weights,
- topk_ids,
- num_expert_group,
- topk_group,
- need_renorm,
- scoring_func,
- routed_scaling_factor,
- )
-
-
-def rocm_aiter_grouped_topk_fake(
- gating_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_expert_group: int,
- topk_group: int,
- need_renorm: bool,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0, # mul to topk_weights
-) -> None:
- pass
-
-
-def rocm_aiter_fused_moe_impl(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weight: torch.Tensor,
- topk_ids: torch.Tensor,
- expert_mask: torch.Tensor | None = None,
- activation_method: int = ActivationMethod.SILU.value,
- quant_method: int = QuantMethod.NO.value,
- doweight_stage1: bool = False,
- w1_scale: torch.Tensor | None = None,
- w2_scale: torch.Tensor | None = None,
- a1_scale: torch.Tensor | None = None,
- a2_scale: torch.Tensor | None = None,
-) -> torch.Tensor:
- from aiter import ActivationType, QuantType
- from aiter.fused_moe import fused_moe
-
- activation = ActivationType(activation_method)
- quant_type = QuantType(quant_method)
-
- return fused_moe(
- hidden_states,
- w1,
- w2,
- topk_weight,
- topk_ids,
- expert_mask,
- activation,
- quant_type,
- doweight_stage1,
- w1_scale,
- w2_scale,
- a1_scale,
- a2_scale,
- )
-
-
-def rocm_aiter_fused_moe_fake(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weight: torch.Tensor,
- topk_ids: torch.Tensor,
- expert_mask: torch.Tensor | None = None,
- activation_method: int = ActivationMethod.SILU.value,
- quant_method: int = QuantMethod.NO.value,
- doweight_stage1: bool = False,
- w1_scale: torch.Tensor | None = None,
- w2_scale: torch.Tensor | None = None,
- a1_scale: torch.Tensor | None = None,
- a2_scale: torch.Tensor | None = None,
-) -> torch.Tensor:
- return torch.empty_like(hidden_states)
-
-
-if current_platform.is_rocm():
- direct_register_custom_op(
- op_name="rocm_aiter_asm_moe_tkw1",
- op_func=rocm_aiter_asm_moe_tkw1_impl,
- fake_impl=rocm_aiter_asm_moe_tkw1_fake,
- )
-
- direct_register_custom_op(
- op_name="rocm_aiter_fused_moe",
- op_func=rocm_aiter_fused_moe_impl,
- fake_impl=rocm_aiter_fused_moe_fake,
- )
-
- direct_register_custom_op(
- op_name="rocm_aiter_topk_softmax",
- op_func=rocm_aiter_topk_softmax_impl,
- mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
- fake_impl=rocm_aiter_topk_softmax_fake,
- )
-
- direct_register_custom_op(
- op_name="rocm_aiter_biased_grouped_topk",
- op_func=rocm_aiter_biased_grouped_topk_impl,
- mutates_args=["topk_weights", "topk_ids"],
- fake_impl=rocm_aiter_biased_grouped_topk_fake,
- )
-
- direct_register_custom_op(
- op_name="rocm_aiter_grouped_topk",
- op_func=rocm_aiter_grouped_topk_impl,
- mutates_args=["topk_weights", "topk_ids"],
- fake_impl=rocm_aiter_grouped_topk_fake,
- )
-
-
def rocm_aiter_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
device = hidden_states.device
- if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
+ if (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ and num_fused_shared_experts > 0
+ ):
assert aiter_topK_meta_data is not None, (
"AITER topK meta data is not initialized. "
"Please ensure that init_aiter_topK_meta_data "
@@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if e_score_correction_bias is not None:
- torch.ops.vllm.rocm_aiter_biased_grouped_topk(
+ rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
topk_weights,
@@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
)
else:
assert scoring_func == "softmax" or scoring_func == "sigmoid"
- torch.ops.vllm.rocm_aiter_grouped_topk(
+ rocm_aiter_ops.grouped_topk(
gating_output,
topk_weights,
topk_ids,
@@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
routed_scaling_factor=routed_scaling_factor,
)
- if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
+ if (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ and num_fused_shared_experts > 0
+ ):
return total_topk_weights, total_topk_ids
return topk_weights, topk_ids
@@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
- return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
+ return rocm_aiter_ops.asm_moe_tkw1(
hidden_states,
w1,
w2,
@@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
else:
quant_method = QuantMethod.NO.value
-
+    # Quark MoE with an mxfp4 weight dtype maps to 1x32 block quantization.
+ if quant_config.use_mxfp4_w4a16:
+ quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
@@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
- return torch.ops.vllm.rocm_aiter_fused_moe(
+ return rocm_aiter_ops.fused_moe(
hidden_states,
w1,
w2,
@@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input,
)
-
-
-def rocm_aiter_topk_softmax(
- topk_weights: torch.Tensor,
- topk_indices: torch.Tensor,
- token_expert_indices: torch.Tensor,
- gating_output: torch.Tensor,
- renormalize: bool,
-) -> tuple[torch.Tensor, ...]:
- torch.ops.vllm.rocm_aiter_topk_softmax(
- topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
- )
- return topk_weights, topk_indices
-
-
-def shuffle_weights(
- *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
-) -> tuple[torch.Tensor, ...]:
- """
- Applies shuffle_weight function from AITER to each
- input tensor and returns them.
-
- Rearranges (shuffles) the input tensor/s
- into a specified block layout for optimized computation.
-
- Args:
- *tensors: Variable number of torch.Tensor objects.
- layout: A pair of integers specifying the block sizes used to divide
- the tensors during shuffling. Default is (16, 16).
-
- Returns:
- A Tuple of shuffled tensors.
- """
- from aiter.ops.shuffle import shuffle_weight
-
- return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index a883ac81f41e..8cc374ac9155 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -6,18 +6,13 @@
import torch.nn as nn
import torch.nn.functional as F
-import vllm.envs as envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
-from vllm.utils.torch_utils import direct_register_custom_op
-
-
-def is_rocm_aiter_rmsnorm_enabled() -> bool:
- return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
def rms_norm(
@@ -58,80 +53,34 @@ def fused_add_rms_norm(
return x, residual
-def rocm_aiter_rms_norm_impl(
- x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+def poly_norm(
+ x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
- import aiter as rocm_aiter
-
- if x.dim() > 2:
- x_original_shape = x.shape
- x = x.reshape(-1, x_original_shape[-1])
- x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
- return x.reshape(x_original_shape)
-
- return rocm_aiter.rms_norm(x, weight, variance_epsilon)
-
+ from vllm import _custom_ops as ops
-def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
- x: torch.Tensor,
- residual: torch.Tensor,
- weight: torch.Tensor,
- variance_epsilon: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
- import aiter as rocm_aiter
-
- residual_out = torch.empty_like(residual)
- output = torch.empty_like(x)
- rocm_aiter.rmsnorm2d_fwd_with_add(
- output, # output
- x, # input
- residual, # residual input
- residual_out, # residual output
+ out = torch.empty_like(x)
+ ops.poly_norm(
+ out,
+ x,
weight,
+ bias,
variance_epsilon,
)
- return output, residual_out
-
-
-def rocm_aiter_rms_norm_fake(
- x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
-) -> torch.Tensor:
- return torch.empty_like(x)
-
-
-def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
- x: torch.Tensor,
- residual: torch.Tensor,
- weight: torch.Tensor,
- variance_epsilon: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
- return torch.empty_like(x), torch.empty_like(residual)
-
-
-if current_platform.is_rocm():
- direct_register_custom_op(
- op_name="rocm_aiter_rms_norm",
- op_func=rocm_aiter_rms_norm_impl,
- fake_impl=rocm_aiter_rms_norm_fake,
- )
-
- direct_register_custom_op(
- op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
- op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
- fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
- )
+ return out
-def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
- use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
+def dispatch_rocm_rmsnorm_func(
+ with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
+):
+ use_aiter = use_aiter and dtype in [
torch.float16,
torch.bfloat16,
]
if use_aiter and with_fused_add:
- return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+ return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter:
- return torch.ops.vllm.rocm_aiter_rms_norm
+ return rocm_aiter_ops.rms_norm
# fall back to CUDA implementation
if with_fused_add:
@@ -169,11 +118,14 @@ def __init__(
self.weight = nn.Parameter(self.weight)
if current_platform.is_rocm():
+ aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
- with_fused_add=False, dtype=weight_dtype
+ with_fused_add=False,
+ dtype=weight_dtype,
+ use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
- with_fused_add=True, dtype=weight_dtype
+ with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)
@staticmethod
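As a usage sketch (assuming a ROCm build with AITER available), the reworked dispatch now takes the AITER toggle explicitly instead of reading environment variables inside the function; the call sites shown in the hunk above pass `rocm_aiter_ops.is_rmsnorm_enabled()`:

```python
import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.layernorm import dispatch_rocm_rmsnorm_func

use_aiter = rocm_aiter_ops.is_rmsnorm_enabled()

# fp16/bf16 with the flag enabled resolves to the AITER kernels
# (rocm_aiter_ops.rms_norm / rocm_aiter_ops.rms_norm2d_with_add).
norm_fn = dispatch_rocm_rmsnorm_func(
    with_fused_add=False, dtype=torch.bfloat16, use_aiter=use_aiter
)

# fp32 never takes the AITER path, so this falls back to the CUDA/HIP op
# even if the flag is set.
fused_norm_fn = dispatch_rocm_rmsnorm_func(
    with_fused_add=True, dtype=torch.float32, use_aiter=True
)
```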
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d95d49eddfe3..d32ae6674ee6 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -12,6 +12,7 @@
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -582,11 +583,8 @@ def __init__(
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- is_rocm_aiter_moe_enabled,
- )
- self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+ self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
@@ -829,12 +827,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
- shuffle_weights,
- )
-
# reshaping weights is required for aiter moe kernel.
- shuffled_w13, shuffled_w2 = shuffle_weights(
+ shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index ee431c9148b8..6da136cbc8f6 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -7,12 +7,12 @@
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
- check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool)
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
- self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+ self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index ce40645782e5..e4e1cbff712f 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -12,6 +12,7 @@
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@@ -56,7 +57,6 @@
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
- check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@@ -369,7 +369,7 @@ def __init__(self, quant_config: Fp8Config):
if vllm_is_batch_invariant():
self.use_marlin = False
- self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+ self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
@@ -869,12 +869,8 @@ def create_weights(
def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- is_rocm_aiter_moe_enabled,
- shuffle_weights,
- )
- self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+ self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
@@ -916,7 +912,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
- shuffled_w13, shuffled_w2 = shuffle_weights(
+ shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@@ -962,7 +958,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
- shuffled_w13, shuffled_w2 = shuffle_weights(
+ shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@@ -1042,7 +1038,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
start += shard_size
if self.rocm_aiter_moe_enabled:
- shuffled_w13, shuffled_w2 = shuffle_weights(
+ shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index a19396a162bc..f5cd91469b78 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -4,54 +4,14 @@
import torch
-import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
-from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
-def rocm_aiter_gemm_w8a8_impl(
- A: torch.Tensor,
- B: torch.Tensor,
- As: torch.Tensor,
- Bs: torch.Tensor,
- bias: torch.Tensor | None = None,
- output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
- from aiter import gemm_a8w8_CK
-
- # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
- # a to be [M, K]
- # b to be [N, K]
- # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
- return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
-
-
-def rocm_aiter_gemm_w8a8_fake(
- A: torch.Tensor,
- B: torch.Tensor,
- As: torch.Tensor,
- Bs: torch.Tensor,
- bias: torch.Tensor | None = None,
- output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
- m = A.shape[0]
- n = B.shape[0]
- Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
- return Y
-
-
-if current_platform.is_rocm():
- direct_register_custom_op(
- op_name="rocm_aiter_gemm_w8a8",
- op_func=rocm_aiter_gemm_w8a8_impl,
- fake_impl=rocm_aiter_gemm_w8a8_fake,
- )
-
-
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@@ -75,7 +35,7 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+ "installed on ROCm.",
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
- if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
+ if not (rocm_aiter_ops.is_linear_enabled()):
return (
False,
"AiterScaledMMLinearKernel is disabled. "
@@ -157,6 +117,4 @@ def apply_weights(
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
- return torch.ops.vllm.rocm_aiter_gemm_w8a8(
- x_q, w_q.t(), x_s, w_s, bias, out_dtype
- )
+ return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
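The removed module-level op gives a good picture of what the new `rocm_aiter_ops.gemm_w8a8` facade forwards to. A minimal sketch, assuming the facade keeps the same argument order; note that `gemm_a8w8_CK` expects the weight as [N, K], which is why the call site passes `w_q.t()`:

```python
import torch


def gemm_w8a8(
    A: torch.Tensor,                     # quantized activations, [M, K]
    B: torch.Tensor,                     # quantized weight, [N, K] (hence w_q.t() at the call site)
    As: torch.Tensor,                    # activation scales
    Bs: torch.Tensor,                    # weight scales
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # ROCm-only dependency: the CK a8w8 GEMM from AITER.
    from aiter import gemm_a8w8_CK

    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
```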
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 8825611051e5..dc215258790e 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -8,6 +8,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@@ -21,10 +22,6 @@
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- is_rocm_aiter_moe_enabled,
- use_mxfp4_aiter_moe,
-)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
)
@@ -122,7 +119,7 @@ def __init__(
if current_platform.is_rocm():
self.use_marlin = False
- self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+ self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
def create_weights(
self,
@@ -309,12 +306,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
- shuffle_weights,
- )
-
# reshaping weights is required for aiter moe kernel.
- shuffled_w13, shuffled_w2 = shuffle_weights(
+ shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@@ -469,13 +462,15 @@ def __init__(
"not implemented. Please open an issue."
)
+ self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
+
self.emulate = not current_platform.supports_mx() or not (
- use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
+ self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
)
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
- f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
+ f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
@@ -644,28 +639,18 @@ def apply(
)
if not self.emulate:
- from aiter import ActivationType, QuantType
- from aiter.fused_moe import fused_moe
-
- aiter_acts = {
- ActivationType.No.name.lower(): ActivationType.No,
- ActivationType.Silu.name.lower(): ActivationType.Silu,
- ActivationType.Gelu.name.lower(): ActivationType.Gelu,
- }
- assert activation in aiter_acts, (
- f"Aiter CK fp4 MoE doesn't support activation {activation}"
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+ rocm_aiter_fused_experts,
)
- out = fused_moe(
+
+ out = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
- topk_weights,
- topk_ids,
- quant_type=QuantType.per_1x32,
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- activation=aiter_acts[activation],
- doweight_stage1=False,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=activation,
+ quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
index c25c522dea55..007e78e68d5c 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
@@ -31,6 +31,13 @@
logger = init_logger(__name__)
+# TODO: move the registration of this custom op to aiter_ops.py
+# (`from vllm._aiter_ops import rocm_aiter_ops`).
+# Use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` for the
+# env checks, which no longer requires @cache.
+# The Triton kernel is torch.compile compatible and does not
+# require direct registration;
+# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_quant`.
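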
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return (
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 7fecda2166ef..63726c07b7d1 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -12,6 +12,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -68,78 +69,6 @@ def cutlass_scaled_mm(
)
-def rocm_aiter_gemm_w8a8_blockscale_impl(
- input_2d: torch.Tensor,
- weight: torch.Tensor,
- input_scale: torch.Tensor,
- weight_scale: torch.Tensor,
- group_size: int,
- output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
- def is_aiter_triton_kernel_tuned(n, k):
- return (n, k) in [
- (1024, 8192),
- (2112, 7168),
- (3072, 1536),
- (32768, 8192),
- (4096, 7168),
- (4608, 7168),
- (512, 7168),
- (7168, 2048),
- (7168, 256),
- (8192, 1024),
- (8192, 32768),
- ]
-
- n, k = weight.shape
- if input_scale is not None:
- q_input = input_2d
- elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
- from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
- # MI350 case uses triton kernel
- q_input, input_scale = per_token_group_quant_fp8(
- input_2d,
- group_size,
- column_major_scales=False,
- use_ue8m0=False,
- )
- else:
- # MI300 uses tuned AITER ASM/C++ kernel
- import aiter as rocm_aiter
- from aiter import gemm_a8w8_blockscale, get_hip_quant
-
- aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
- q_input, input_scale = aiter_per1x128_quant(
- input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
- )
-
- return gemm_a8w8_blockscale(
- q_input, weight, input_scale, weight_scale, dtype=output_dtype
- )
-
-
-def rocm_aiter_gemm_w8a8_blockscale_fake(
- input_2d: torch.Tensor,
- weight: torch.Tensor,
- input_scale: torch.Tensor,
- weight_scale: torch.Tensor,
- group_size: int,
- output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
- m = input_2d.shape[0]
- n = weight.shape[0]
- return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
-
-
-if current_platform.is_rocm():
- direct_register_custom_op(
- op_name="rocm_aiter_gemm_w8a8_blockscale",
- op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
- fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
- )
-
-
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
@@ -385,14 +314,40 @@ def _run_aiter(
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
- return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
- input_2d,
- weight,
- input_scale,
- weight_scale,
- self.act_quant_group_shape.col,
- input_2d.dtype,
- )
+
+ n, k = weight.shape
+ if input_scale is not None:
+ q_input = input_2d
+
+ # MI350 case uses triton kernel
+ if (
+ not current_platform.is_fp8_fnuz()
+ and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
+ ):
+ q_input, input_scale = per_token_group_quant_fp8(
+ input_2d,
+ self.act_quant_group_shape.col,
+ column_major_scales=False,
+ use_ue8m0=False,
+ )
+ return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
+ q_input,
+ weight,
+ input_scale,
+ weight_scale,
+ input_2d.dtype,
+ )
+
+ # MI300 uses tuned AITER ASM/C++ kernel
+ else:
+ q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
+ return rocm_aiter_ops.gemm_w8a8_blockscale(
+ q_input,
+ weight,
+ input_scale,
+ weight_scale,
+ input_2d.dtype,
+ )
def _run_triton(
self,
@@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
-def check_aiter_fp8_linear_support() -> bool:
- """AITER is only supported on ROCm for MI3XX"""
- return (
- current_platform.is_rocm()
- and envs.VLLM_ROCM_USE_AITER
- and envs.VLLM_ROCM_USE_AITER_LINEAR
- )
-
-
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
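The tuned-shape check that used to live inside `rocm_aiter_gemm_w8a8_blockscale_impl` now sits behind `rocm_aiter_ops.is_triton_gemm_w8a8_tuned`. A minimal sketch reconstructed from the table removed above; the real facade may store or extend this list differently:

```python
# Shapes (N, K) for which the AITER Triton a8w8 blockscale GEMM is tuned,
# taken verbatim from the removed implementation.
_TUNED_W8A8_BLOCKSCALE_SHAPES: set[tuple[int, int]] = {
    (1024, 8192), (2112, 7168), (3072, 1536), (32768, 8192),
    (4096, 7168), (4608, 7168), (512, 7168), (7168, 2048),
    (7168, 256), (8192, 1024), (8192, 32768),
}


def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
    return (n, k) in _TUNED_W8A8_BLOCKSCALE_SHAPES
```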
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 380431e86435..7fe902807a74 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -472,7 +472,7 @@ def apply(
# Example:
# When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or ().
- per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
+ per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 91276320df4d..2ef54e75df44 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -4,13 +4,10 @@
import torch
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
-from .rocm_aiter_rope_ops import (
- is_rocm_triton_rotary_embedding_enabled,
- rocm_aiter_rotary_emb,
-)
@CustomOp.register("rotary_embedding")
@@ -48,8 +45,8 @@ def __init__(
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
- self.is_rocm_triton_rotary_embedding_enabled = (
- is_rocm_triton_rotary_embedding_enabled()
+ self.is_rocm_triton_rotary_embed_enabled = (
+ rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
@@ -169,9 +166,9 @@ def forward_hip(
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
- if self.is_rocm_triton_rotary_embedding_enabled:
+ if self.is_rocm_triton_rotary_embed_enabled:
self._match_cos_sin_cache_dtype(query)
- rocm_aiter_rotary_emb(
+ rocm_aiter_ops.triton_rotary_embed(
positions,
query,
key,
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index d9134f05fddf..e72834e473c1 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -146,6 +146,15 @@ def forward_native(
key = key_rot
return query, key
+ def forward_hip(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ return self.forward_native(positions, query, key, offsets)
+
def forward_cuda(
self,
positions: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
deleted file mode 100644
index a01d14f7b3a1..000000000000
--- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm.envs as envs
-from vllm.platforms import current_platform
-from vllm.utils.torch_utils import direct_register_custom_op
-
-
-def is_rocm_triton_rotary_embedding_enabled() -> bool:
- return (
- current_platform.is_rocm()
- and envs.VLLM_ROCM_USE_AITER
- and envs.VLLM_ROCM_USE_TRITON_ROPE
- )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
- positions: torch.Tensor,
- sin: torch.Tensor,
- cos: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor,
- rotate_style: int = 0,
- is_nope_first: bool = False,
-) -> None:
- import aiter.ops.triton.rope as ops
-
- ops.rope_cached_thd_positions_2c_fwd_inplace(
- query,
- key,
- cos,
- sin,
- positions,
- rotate_style,
- reuse_freqs_front_part=True,
- nope_first=is_nope_first,
- )
-
-
-def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
- positions: torch.Tensor,
- sin: torch.Tensor,
- cos: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor,
- rotate_style: int = 0,
- is_nope_first: bool = False,
-) -> None:
- pass
-
-
-if is_rocm_triton_rotary_embedding_enabled():
- direct_register_custom_op(
- op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
- op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
- mutates_args=["key", "query"],
- fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
- dispatch_key=current_platform.dispatch_key,
- )
-
-
-def rocm_aiter_rotary_emb(
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor,
- cos_sin_cache: torch.Tensor,
- head_size: int,
- rotary_dim: int,
- is_neox_style: bool,
-):
- num_tokens = positions.numel()
- cos, sin = cos_sin_cache.chunk(2, dim=-1)
- query_shape = query.shape
- key_shape = key.shape
- rotate_style = 0 if is_neox_style else 1
-
- query = query.view(num_tokens, -1, head_size)
- key = key.view(num_tokens, -1, head_size)
- query_ = query[..., :rotary_dim]
- key_ = key[..., :rotary_dim]
- positions = positions.view(*query.shape[:1])
- torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
- positions,
- sin,
- cos,
- query_,
- key_,
- rotate_style,
- False,
- )
- query = query.view(query_shape)
- key = key.view(key_shape)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 63eaf63cc3c4..38189e17f7d8 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -33,6 +33,7 @@
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
@@ -50,10 +51,6 @@
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- is_rocm_aiter_fusion_shared_expert_enabled,
- is_rocm_aiter_moe_enabled,
-)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -294,10 +291,8 @@ def __init__(
self.physical_expert_start + self.n_local_physical_experts
)
- if (
- config.n_shared_experts is None
- or is_rocm_aiter_fusion_shared_expert_enabled()
- ):
+ self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
+ if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -330,14 +325,14 @@ def __init__(
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
- if not is_rocm_aiter_moe_enabled()
+ if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
- if is_rocm_aiter_fusion_shared_expert_enabled()
+ if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
else None,
)
@@ -371,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
- if not is_rocm_aiter_moe_enabled():
+ if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
@@ -1428,6 +1423,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ rocm_aiter_moe_shared_expert_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
@@ -1456,7 +1454,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
- if is_rocm_aiter_fusion_shared_expert_enabled()
+ if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts,
@@ -1472,9 +1470,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if spec_layer is not None:
continue # skip spec decode layers for main model
- is_fuse_shared_experts_layer = (
- is_rocm_aiter_fusion_shared_expert_enabled()
- and ("mlp.shared_experts" in name)
+ is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
+ "mlp.shared_experts" in name
)
for param_name, weight_name, shard_id in stacked_params_mapping:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 1abd6300036d..e6536a02a73d 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
+ from vllm._aiter_ops import rocm_aiter_ops
+
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
- and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
+ and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)
@@ -202,12 +204,15 @@ class RocmPlatform(Platform):
]
@classmethod
- def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
+ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from importlib.util import find_spec
+ from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
- if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+ if rocm_aiter_ops.is_mha_enabled():
+ # Note: AITER FA is only supported for Qwen-VL models.
+ # TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
@@ -228,19 +233,23 @@ def get_attn_backend_cls(
has_sink,
use_sparse,
) -> str:
+ from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
- if use_mla:
- from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
- is_aiter_mla_enabled,
+
+ if not use_v1:
+ raise RuntimeError(
+ "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
+ "to select a supported backend."
)
+ if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
- if is_aiter_mla_enabled() or block_size == 1
+ if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
)
@@ -265,12 +274,12 @@ def get_attn_backend_cls(
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
- envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
+ rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
- envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+ rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (
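The attention-backend selection above replaces raw `envs.*` conjunctions with named predicates on `rocm_aiter_ops`. A hedged sketch of the equivalent checks, reconstructed from the expressions removed in this hunk; the actual helpers may additionally fold in platform checks such as the gfx9-only MHA constraint:

```python
import vllm.envs as envs


def is_pa_attn_enabled() -> bool:
    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_PAGED_ATTN


def is_mha_enabled() -> bool:
    # The removed check also required on_gfx9(); that constraint may live
    # inside the real helper or at the call site.
    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA


def is_triton_unified_attn_enabled() -> bool:
    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
```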
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 6c8145b6847d..102c4a288d13 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -198,6 +198,7 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
@@ -271,28 +272,15 @@ class QueryLenSupport(Enum):
flashinfer_available = False
-def is_rocm_aiter_fp8bmm_enabled() -> bool:
- return (
- current_platform.is_rocm()
- and envs.VLLM_ROCM_USE_AITER_FP8BMM
- and envs.VLLM_ROCM_USE_AITER
- )
-
-
-if is_rocm_aiter_fp8bmm_enabled():
- from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
- batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501
- )
-
- def dynamic_per_batched_tensor_quant(
- x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
- ):
- DTYPE_MAX = torch.finfo(dtype).max
- min_val, max_val = x.aminmax()
- amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
- scale = DTYPE_MAX / amax
- x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
- return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+def dynamic_per_batched_tensor_quant(
+ x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+):
+ DTYPE_MAX = torch.finfo(dtype).max
+ min_val, max_val = x.aminmax()
+ amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
+ scale = DTYPE_MAX / amax
+ x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
+ return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
logger = init_logger(__name__)
@@ -1114,6 +1102,7 @@ def __init__(
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
+ self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
@@ -1163,7 +1152,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
- if is_rocm_aiter_fp8bmm_enabled():
+ if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@@ -1192,7 +1181,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
dtype=torch.bfloat16,
device=self.W_K.device,
)
- aiter_triton_fp8_bmm(
+ rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
@@ -1201,7 +1190,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
dtype=torch.bfloat16,
device=self.W_V.device,
)
- aiter_triton_fp8_bmm(
+ rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
@@ -1213,10 +1202,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-
- if is_rocm_aiter_fp8bmm_enabled():
+ if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
- x = aiter_triton_fp8_bmm(
+ x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
# Convert from (B, N, V) to (B, N * V)
@@ -1576,7 +1564,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
- if is_rocm_aiter_fp8bmm_enabled():
+ if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@@ -1605,7 +1593,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
dtype=torch.bfloat16,
device=self.W_K.device,
)
- aiter_triton_fp8_bmm(
+ rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
@@ -1614,7 +1602,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
dtype=torch.bfloat16,
device=self.W_V.device,
)
- aiter_triton_fp8_bmm(
+ rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
@@ -1963,7 +1951,6 @@ def forward(
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
- # Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
@@ -1971,9 +1958,9 @@ def forward(
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded
- if is_rocm_aiter_fp8bmm_enabled():
+ if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
- decode_ql_nope = aiter_triton_fp8_bmm(
+ decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
self.W_K,
self.W_K_scale,
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 71eac84b6f06..87733c6af05c 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -6,9 +6,8 @@
import torch
-import vllm.envs as envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer
-from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (
@@ -22,10 +21,6 @@
from vllm.v1.kv_cache_interface import AttentionSpec
-def is_aiter_mla_enabled() -> bool:
- return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA
-
-
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
@@ -288,7 +283,7 @@ def _forward_decode(
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
- aiter_mla_decode_fwd(
+ rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,