diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 633e23eea33e..ee224e6922fb 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | -| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] | +| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 014df1fa111f..c27cf2468ede 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,6 +6,8 @@ """ import functools +import importlib +import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -20,6 +22,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context @@ -412,14 +415,12 @@ def test_mixtral_moe( huggingface.""" # clear the cache before every test - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) + # Force reload aiter_ops to pick up the new environment variables. + if "rocm_aiter_ops" in sys.modules: + importlib.reload(rocm_aiter_ops) - is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 41419553aa83..9121284de85b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import ( @@ -15,9 +16,6 @@ dispatch_topk_func, vllm_topk_softmax, ) -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, -) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_rocm_rmsnorm_func, @@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - topk_func = dispatch_topk_func() - is_rocm_aiter_moe_enabled.cache_clear() - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax, - ) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_topk_dispatch(use_rocm_aiter: bool): + topk_func = dispatch_topk_func(use_rocm_aiter) - assert topk_func == rocm_aiter_topk_softmax + if current_platform.is_rocm() and use_rocm_aiter: + assert topk_func == rocm_aiter_ops.topk_softmax else: assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.skipif( not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" ) def test_rms_norm_dispatch( - 
add_residual: bool, - dtype: torch.dtype, - use_rocm_aiter: str, - use_rocm_aiter_norm: str, - monkeypatch, + add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool ): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) - rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) should_use_rocm_aiter = ( current_platform.is_rocm() - and int(use_rocm_aiter) - and int(use_rocm_aiter_norm) + and use_rocm_aiter and dtype in RMS_NORM_SUPPORTED_DTYPES ) if add_residual and should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add elif should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm + assert rms_norm_func == rocm_aiter_ops.rms_norm elif add_residual: assert rms_norm_func == fused_add_rms_norm else: diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py new file mode 100644 index 000000000000..9a4b5f3399be --- /dev/null +++ b/vllm/_aiter_ops.py @@ -0,0 +1,941 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer + + +def is_aiter_found() -> bool: + from importlib.util import find_spec + + return find_spec("aiter") is not None + + +# `find_spec` is not torch.compile compatible. +# In cases where aiter availability might have +# been checked in forward passes that are torch compiled. +# we keep this global outside to not cause torch compile breaks. +IS_AITER_FOUND = is_aiter_found() + + +def if_aiter_supported(func: Callable) -> Callable: + """Decorator that only executes the function if + ROCm AITER package is supported on gfx9 archs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # checks the platform, device arch and aiter library existance. 
+ + from vllm.platforms.rocm import on_gfx9 + + if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND: + return func(*args, **kwargs) + else: + # Return None or do nothing if not supported + return None + + return wrapper + + +def _rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def _rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_asm_moe_tkw1_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = ActivationType(activation_method) + + return asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation, + ) + + +def _rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_topk_softmax_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + from aiter import topk_softmax + + topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + + +def _rocm_aiter_topk_softmax_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + 
token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + pass + + +def _rocm_aiter_biased_grouped_topk_impl( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + from aiter import biased_grouped_topk + + biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + +def _rocm_aiter_biased_grouped_topk_fake( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_grouped_topk_impl( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + is_softmax = scoring_func == "softmax" + from aiter import grouped_topk + + grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + is_softmax, + routed_scaling_factor, + ) + + +def _rocm_aiter_grouped_topk_fake( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + +def _rocm_aiter_mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +def _rocm_aiter_gemm_w8a8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) + + +def _rocm_aiter_gemm_w8a8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_gemm_w8a8_blockscale_impl( + A: 
torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + +def _rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_rms_norm_impl( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + from aiter import rms_norm + + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rms_norm(x, weight, variance_epsilon) + + +def _rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + return torch.empty_like(x) + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter import rmsnorm2d_fwd_with_add + + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) + rmsnorm2d_fwd_with_add( + output, # output + x, # input + residual, # residual input + residual_out, # residual output + weight, + variance_epsilon, + ) + return output, residual_out + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +# Global flag to ensure ops are registered only once +_OPS_REGISTERED = False + + +class rocm_aiter_ops: + _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + + @classmethod + @if_aiter_supported + def is_enabled(cls) -> bool: + """Verifies device specs and availability of aiter main env variable.""" + return cls._AITER_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._LINEAR_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_fp8_enaled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + + @classmethod + @if_aiter_supported + def is_rmsnorm_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._RMSNORM_ENABLED + + @classmethod + @if_aiter_supported + def is_fused_moe_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and 
cls._FMOE_ENABLED + + @classmethod + @if_aiter_supported + def is_fusion_moe_shared_experts_enabled(cls) -> bool: + return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED + + @classmethod + @if_aiter_supported + def is_mla_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MLA_ENABLED + + @classmethod + @if_aiter_supported + def is_mha_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MHA_ENABLED + + @classmethod + @if_aiter_supported + def is_pa_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_triton_unified_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_fp8bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + + @classmethod + @if_aiter_supported + def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM + + @classmethod + @if_aiter_supported + def is_triton_rotary_embed_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED + + @staticmethod + @if_aiter_supported + def register_ops_once() -> None: + global _OPS_REGISTERED + if not _OPS_REGISTERED: + tags = ( + tuple() + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ) + + # register all the custom ops here + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=_rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=_rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fused_moe", + op_func=_rocm_aiter_fused_moe_impl, + mutates_args=[], + fake_impl=_rocm_aiter_fused_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=_rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=_rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_biased_grouped_topk", + op_func=_rocm_aiter_biased_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_biased_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_grouped_topk", + op_func=_rocm_aiter_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_mla_decode_fwd", + op_func=_rocm_aiter_mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=_rocm_aiter_mla_decode_fwd_fake, + tags=tags, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8", + op_func=_rocm_aiter_gemm_w8a8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=_rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake, + 
dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=_rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + + _OPS_REGISTERED = True + + @staticmethod + def rms_norm2d_with_add( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add( + x, residual, weight, variance_epsilon + ) + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + + @staticmethod + def gemm_w8a8( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype) + + @staticmethod + def gemm_w8a8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + A, B, As, Bs, output_dtype + ) + + @staticmethod + def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation_method, + quant_method, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + @staticmethod + def asm_moe_tkw1( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale, + fc2_scale, + fc1_smooth_scale, + fc2_smooth_scale, + a16, + per_tensor_quant_scale, + expert_mask, + activation_method, + ) + + @staticmethod + def topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + ) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + return topk_weights, topk_indices + + @staticmethod + def biased_grouped_topk( + gating_output: torch.Tensor, + correction_bias: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + @staticmethod + def grouped_topk( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + scoring_func, + routed_scaling_factor, + ) + + @staticmethod + def mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + logit_cap: float = 0.0, + ): + torch.ops.vllm.rocm_aiter_mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + @staticmethod + def triton_fp4_gemm_dynamic_qaunt( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype | None = torch.bfloat16, + x_scales: torch.Tensor | None = None, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + if x_scales is None: + x_q, x_s = dynamic_mxfp4_quant(x) + else: + x_q = x + x_s = x_scales + + y = torch.empty( + x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) + return y + + @staticmethod + def triton_rotary_embed( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, + is_neox_style: bool, + ): + from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace + + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox_style else 1 + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + rope_cached_thd_positions_2c_fwd_inplace( + positions, + sin, + cos, + query_, + key_, + rotate_style, + reuse_freqs_front_part=True, + is_nope_first=False, + ) + query = query.view(query_shape) + key = key.view(key_shape) + + @staticmethod + def triton_fp8_bmm( + X: torch.Tensor, + WQ: torch.Tensor, + w_scale: torch.Tensor, + group_size: int = 128, + bias: torch.Tensor | None = None, + dtype: torch.dtype | None = torch.bfloat16, + splitK: int | None = None, + YQ: torch.Tensor | None = None, + transpose_bm: bool | None = False, + config: dict | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, + ) + + return aiter_triton_fp8_bmm( + X, + WQ, 
+ w_scale, + group_size=group_size, + bias=bias, + dtype=dtype, + splitK=splitK, + YQ=YQ, + transpose_bm=transpose_bm, + config=config, + ) + + @staticmethod + def triton_gemm_a8w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + @staticmethod + def per_1x128_fp8_quant( + input_2d: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Only applies quantization method for fp8 data type only.""" + from aiter import QuantType, dtypes, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8) + + @staticmethod + def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: + return (n, k) in [ + (1024, 8192), + (2112, 7168), + (3072, 1536), + (32768, 8192), + (4096, 7168), + (4608, 7168), + (512, 7168), + (7168, 2048), + (7168, 256), + (8192, 1024), + (8192, 32768), + ] + + @staticmethod + def shuffle_weight( + self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> torch.Tensor: + from aiter.ops.shuffle import shuffle_weight + + return shuffle_weight(tensor, layout=layout) + + @staticmethod + def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> tuple[torch.Tensor, ...]: + """ + Applies shuffle_weight function from AITER to each + input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. + + Args: + *tensors: Variable number of torch.Tensor objects. + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). + + Returns: + A Tuple of shuffled tensors. 
+ """ + from aiter.ops.shuffle import shuffle_weight + + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) + + +rocm_aiter_ops.register_ops_once() diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py deleted file mode 100644 index 6308f63cc4e7..000000000000 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -import torch - -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer - - -def get_aiter_mla_metadata( - max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device -) -> tuple[torch.Tensor, ...]: - paged_kv_indices = torch.zeros( - max_batch_size * max_block_per_batch, dtype=torch.int32, device=device - ) - paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device) - paged_kv_last_page_lens = torch.full( - (max_batch_size,), block_size, dtype=torch.int32 - ) - qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) - return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr - - -def aiter_mla_decode_fwd( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - sm_scale: float, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - logit_cap: float = 0.0, -): - torch.ops.vllm.rocm_aiter_mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - max_seqlen_qo, - kv_indptr, - kv_indices, - kv_last_page_lens, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_impl( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - from aiter.mla import mla_decode_fwd - - mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_fake( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - pass - - -if current_platform.is_rocm(): - if is_torch_equal_or_newer("2.7.0"): - tags = () - else: - tags = ((torch.Tag.needs_fixed_stride_order,),) - direct_register_custom_op( - op_name="rocm_aiter_mla_decode_fwd", - op_func=mla_decode_fwd_impl, - mutates_args=["o"], - fake_impl=mla_decode_fwd_fake, - tags=tags, - ) diff --git a/vllm/envs.py b/vllm/envs.py index 078e5c38f0f4..30c62e90e9fb 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -109,7 +109,7 @@ VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_TRITON_ROPE: bool = False + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True @@ -926,8 +926,8 
@@ def get_vllm_port() -> int | None: ), # Whether to use aiter rope. # By default is disabled. - "VLLM_ROCM_USE_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") + "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. @@ -1589,7 +1589,7 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", - "VLLM_ROCM_USE_TRITON_ROPE", + "VLLM_ROCM_USE_AITER_TRITON_ROPE", "VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_AITER_TRITON_GEMM", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7ad3ce1397b3..2e042d85fcfc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -14,6 +14,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -55,8 +56,6 @@ from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer -from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled - logger = init_logger(__name__) @@ -1089,11 +1088,11 @@ def vllm_topk_softmax( return topk_weights, topk_indices -def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax - - return rocm_aiter_topk_softmax +def dispatch_topk_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_softmax return vllm_topk_softmax @@ -1121,7 +1120,7 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - topk_func = dispatch_topk_func() + topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) topk_weights, topk_ids = topk_func( topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e69ead074c50..45b0f50a7997 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,6 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( @@ -41,8 +42,6 @@ ) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, ) from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( @@ -92,13 +91,11 @@ def _eplb_map_to_physical_and_record( return topk_ids eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record +from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + 
rocm_aiter_grouped_topk, +) -if is_rocm_aiter_moe_enabled(): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk_aiter, - ) -else: - from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: @@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -620,13 +618,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Padding the weight for better performance on ROCm layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - # Lazy import to avoid importing triton. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights, - ) if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -1002,6 +996,7 @@ def determine_expert_map( global_num_experts: int, expert_placement_strategy: ExpertPlacementStrategy = "linear", num_fused_shared_experts: int = 0, + return_expert_mask: bool = False, ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: """ Calculates how many experts should be assigned to each rank for EP and @@ -1064,7 +1059,7 @@ def determine_expert_map( ) expert_mask = None - if is_rocm_aiter_moe_enabled(): + if return_expert_mask: expert_mask = torch.ones( (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 ) @@ -1292,14 +1287,18 @@ def __init__( self.logical_replica_count: torch.Tensor | None = None # ROCm aiter shared experts fusion + self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.aiter_fmoe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + self.num_fused_shared_experts = ( n_shared_experts - if n_shared_experts is not None - and is_rocm_aiter_fusion_shared_expert_enabled() + if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled else 0 ) if ( - not is_rocm_aiter_fusion_shared_expert_enabled() + not self.aiter_fmoe_shared_expert_enabled and self.num_fused_shared_experts != 0 ): raise ValueError( @@ -1346,6 +1345,7 @@ def __init__( global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) @@ -1570,13 +1570,16 @@ def update_expert_map(self): ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) self.register_buffer("expert_mask", expert_mask) - self._init_aiter_shared_experts_topK_buffer( - vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size - ) + if self.aiter_fmoe_shared_expert_enabled: + self._init_aiter_shared_experts_topK_buffer( + 
vllm_config=get_current_vllm_config(), + dp_size=get_dp_group().world_size, + ) def _load_per_tensor_weight_scale( self, @@ -1753,20 +1756,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _init_aiter_shared_experts_topK_buffer( self, vllm_config: VllmConfig, dp_size: int ): - if is_rocm_aiter_fusion_shared_expert_enabled(): - if self.num_fused_shared_experts > 0: - init_aiter_topK_meta_data( - n_routed_experts=self.global_num_experts, - n_shared_experts=self.num_fused_shared_experts, - top_k=self.top_k, - tp_rank=self.ep_rank if self.use_ep else self.tp_rank, - tp_size=self.ep_size if self.use_ep else self.tp_size, - shared_experts_score=1.0, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens - * dp_size, - is_EP=self.use_ep, - ) - self.local_num_experts += self.num_fused_shared_experts + if self.num_fused_shared_experts > 0: + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens + * dp_size, + is_EP=self.use_ep, + ) + self.local_num_experts += self.num_fused_shared_experts @overload def weight_loader( @@ -2208,15 +2210,16 @@ def select_experts( elif use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - if is_rocm_aiter_moe_enabled(): - if not is_rocm_aiter_fusion_shared_expert_enabled(): + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): assert num_fused_shared_experts == 0 grouped_topk_impl = partial( - grouped_topk_aiter, + rocm_aiter_grouped_topk, num_fused_shared_experts=num_fused_shared_experts, ) else: grouped_topk_impl = grouped_topk + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, @@ -2448,7 +2451,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, @@ -2612,7 +2615,7 @@ def forward_impl( use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index e18514ad43f6..8f05828d74f5 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum -from functools import cache, lru_cache +from functools import lru_cache import torch -from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, ) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op class 
QuantMethod(IntEnum): @@ -37,27 +35,6 @@ class ActivationMethod(IntEnum): GELU = 1 -@cache -def is_rocm_aiter_moe_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_MOE - and envs.VLLM_ROCM_USE_AITER - ) - - -@cache -def use_mxfp4_aiter_moe() -> bool: - return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER - - -@cache -def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: - return ( - envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled() - ) - - aiter_topK_meta_data = None @@ -114,250 +91,6 @@ def init_aiter_topK_meta_data( aiter_topK_meta_data = (total_topk_weights, total_topk_ids) -def rocm_aiter_asm_moe_tkw1_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - from aiter import ActivationType - from aiter.fused_moe_bf16_asm import asm_moe_tkw1 - - activation = ActivationType(activation_method) - - return asm_moe_tkw1( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - per_tensor_quant_scale=per_tensor_quant_scale, - expert_mask=expert_mask, - activation=activation, - ) - - -def rocm_aiter_asm_moe_tkw1_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_topk_softmax_impl( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - from aiter import topk_softmax - - topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - - -def rocm_aiter_topk_softmax_fake( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - pass - - -def rocm_aiter_biased_grouped_topk_impl( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import biased_grouped_topk - - biased_grouped_topk( - gating_output, - correction_bias, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - routed_scaling_factor, - ) - - -def rocm_aiter_biased_grouped_topk_fake( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 
1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_grouped_topk_impl( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import grouped_topk - - grouped_topk( - gating_output, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - scoring_func, - routed_scaling_factor, - ) - - -def rocm_aiter_grouped_topk_fake( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_fused_moe_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - activation = ActivationType(activation_method) - quant_type = QuantType(quant_method) - - return fused_moe( - hidden_states, - w1, - w2, - topk_weight, - topk_ids, - expert_mask, - activation, - quant_type, - doweight_stage1, - w1_scale, - w2_scale, - a1_scale, - a2_scale, - ) - - -def rocm_aiter_fused_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_asm_moe_tkw1", - op_func=rocm_aiter_asm_moe_tkw1_impl, - fake_impl=rocm_aiter_asm_moe_tkw1_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_fused_moe", - op_func=rocm_aiter_fused_moe_impl, - fake_impl=rocm_aiter_fused_moe_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_topk_softmax", - op_func=rocm_aiter_topk_softmax_impl, - mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], - fake_impl=rocm_aiter_topk_softmax_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_biased_grouped_topk", - op_func=rocm_aiter_biased_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_biased_grouped_topk_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_grouped_topk", - op_func=rocm_aiter_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_grouped_topk_fake, - ) - - def rocm_aiter_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk( ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device - if is_rocm_aiter_fusion_shared_expert_enabled() and 
num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): assert aiter_topK_meta_data is not None, ( "AITER topK meta data is not initialized. " "Please ensure that init_aiter_topK_meta_data " @@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk( topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: - torch.ops.vllm.rocm_aiter_biased_grouped_topk( + rocm_aiter_ops.biased_grouped_topk( gating_output, e_score_correction_bias.to(gating_output.dtype), topk_weights, @@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk( ) else: assert scoring_func == "softmax" or scoring_func == "sigmoid" - torch.ops.vllm.rocm_aiter_grouped_topk( + rocm_aiter_ops.grouped_topk( gating_output, topk_weights, topk_ids, @@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk( routed_scaling_factor=routed_scaling_factor, ) - if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): return total_topk_weights, total_topk_ids return topk_weights, topk_ids @@ -464,7 +203,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + return rocm_aiter_ops.asm_moe_tkw1( hidden_states, w1, w2, @@ -482,7 +221,9 @@ def rocm_aiter_fused_experts( else: quant_method = QuantMethod.NO.value - + # quark moe for mxfp4 w_dtype + if quant_config.use_mxfp4_w4a16: + quant_method = QuantMethod.BLOCK_1X32.value # w8a8 block-scaled if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( @@ -507,7 +248,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_fused_moe( + return rocm_aiter_ops.fused_moe( hidden_states, w1, w2, @@ -522,39 +263,3 @@ def rocm_aiter_fused_experts( a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input, ) - - -def rocm_aiter_topk_softmax( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> tuple[torch.Tensor, ...]: - torch.ops.vllm.rocm_aiter_topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - return topk_weights, topk_indices - - -def shuffle_weights( - *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) -) -> tuple[torch.Tensor, ...]: - """ - Applies shuffle_weight function from AITER to each - input tensor and returns them. - - Rearranges (shuffles) the input tensor/s - into a specified block layout for optimized computation. - - Args: - *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the block sizes used to divide - the tensors during shuffling. Default is (16, 16). - - Returns: - A Tuple of shuffled tensors. 
- """ - from aiter.ops.shuffle import shuffle_weight - - return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a883ac81f41e..8cc374ac9155 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -6,18 +6,13 @@ import torch.nn as nn import torch.nn.functional as F -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.batch_invariant import ( rms_norm_batch_invariant, vllm_is_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_aiter_rmsnorm_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER def rms_norm( @@ -58,80 +53,34 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm_impl( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +def poly_norm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: - import aiter as rocm_aiter - - if x.dim() > 2: - x_original_shape = x.shape - x = x.reshape(-1, x_original_shape[-1]) - x = rocm_aiter.rms_norm(x, weight, variance_epsilon) - return x.reshape(x_original_shape) - - return rocm_aiter.rms_norm(x, weight, variance_epsilon) - + from vllm import _custom_ops as ops -def rocm_aiter_rmsnorm2d_fwd_with_add_impl( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - import aiter as rocm_aiter - - residual_out = torch.empty_like(residual) - output = torch.empty_like(x) - rocm_aiter.rmsnorm2d_fwd_with_add( - output, # output - x, # input - residual, # residual input - residual_out, # residual output + out = torch.empty_like(x) + ops.poly_norm( + out, + x, weight, + bias, variance_epsilon, ) - return output, residual_out - - -def rocm_aiter_rms_norm_fake( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float -) -> torch.Tensor: - return torch.empty_like(x) - - -def rocm_aiter_rmsnorm2d_fwd_with_add_fake( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(x), torch.empty_like(residual) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_rms_norm", - op_func=rocm_aiter_rms_norm_impl, - fake_impl=rocm_aiter_rms_norm_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_rmsnorm2d_fwd_with_add", - op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, - fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, - ) + return out -def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): - use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ +def dispatch_rocm_rmsnorm_func( + with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False +): + use_aiter = use_aiter and dtype in [ torch.float16, torch.bfloat16, ] if use_aiter and with_fused_add: - return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + return rocm_aiter_ops.rms_norm2d_with_add if use_aiter: - return torch.ops.vllm.rocm_aiter_rms_norm + return rocm_aiter_ops.rms_norm # fall back to CUDA implementation if with_fused_add: @@ -169,11 +118,14 @@ def __init__( self.weight = nn.Parameter(self.weight) if current_platform.is_rocm(): + aiter_rmsnorm_enabled = 
rocm_aiter_ops.is_rmsnorm_enabled() self.rocm_norm_func = dispatch_rocm_rmsnorm_func( - with_fused_add=False, dtype=weight_dtype + with_fused_add=False, + dtype=weight_dtype, + use_aiter=aiter_rmsnorm_enabled, ) self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( - with_fused_add=True, dtype=weight_dtype + with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled ) @staticmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d95d49eddfe3..d32ae6674ee6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -582,11 +583,8 @@ def __init__( # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( @@ -829,12 +827,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee431c9148b8..6da136cbc8f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -7,12 +7,12 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) ) self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() if self.weight_block_size is not None: assert not self.is_static_input_scheme diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ce40645782e5..e4e1cbff712f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( @@ -56,7 +57,6 @@ ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -369,7 +369,7 @@ def __init__(self, quant_config: Fp8Config): if vllm_is_batch_invariant(): self.use_marlin = False - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() self.use_deep_gemm = is_deep_gemm_supported() self.weight_block_size = self.quant_config.weight_block_size @@ -869,12 +869,8 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - shuffle_weights, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -916,7 +912,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -962,7 +958,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1042,7 +1038,7 @@ def process_weights_after_loading(self, layer: Module) -> None: start += shard_size if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bc..f5cd91469b78 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -4,54 +4,14 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig -def rocm_aiter_gemm_w8a8_impl( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - from aiter import gemm_a8w8_CK - - # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects - # a to be [M, K] - # b to be [N, K] - # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) - - -def rocm_aiter_gemm_w8a8_fake( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = A.shape[0] - n = B.shape[0] - Y = torch.empty(m, n, dtype=output_dtype, device=A.device) - return Y - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8", - op_func=rocm_aiter_gemm_w8a8_impl, - fake_impl=rocm_aiter_gemm_w8a8_fake, - ) - - class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -75,7 +35,7 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + "installed on ROCm.", ) # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + if not (rocm_aiter_ops.is_linear_enabled()): return ( False, "AiterScaledMMLinearKernel is disabled. 
" @@ -157,6 +117,4 @@ def apply_weights( # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return torch.ops.vllm.rocm_aiter_gemm_w8a8( - x_q, w_q.t(), x_s, w_s, bias, out_dtype - ) + return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 8825611051e5..dc215258790e 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -21,10 +22,6 @@ ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - use_mxfp4_aiter_moe, -) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) @@ -122,7 +119,7 @@ def __init__( if current_platform.is_rocm(): self.use_marlin = False - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() def create_weights( self, @@ -309,12 +306,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -469,13 +462,15 @@ def __init__( "not implemented. Please open an issue." ) + self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() + self.emulate = not current_platform.supports_mx() or not ( - use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" ) if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " - f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, " + f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. 
Simulated weight dequantization and activation " @@ -644,28 +639,18 @@ def apply( ) if not self.emulate: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - aiter_acts = { - ActivationType.No.name.lower(): ActivationType.No, - ActivationType.Silu.name.lower(): ActivationType.Silu, - ActivationType.Gelu.name.lower(): ActivationType.Gelu, - } - assert activation in aiter_acts, ( - f"Aiter CK fp4 MoE doesn't support activation {activation}" + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, ) - out = fused_moe( + + out = rocm_aiter_fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights, - topk_ids, - quant_type=QuantType.per_1x32, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - activation=aiter_acts[activation], - doweight_stage1=False, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + quant_config=self.moe_quant_config, ) else: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c25c522dea55..007e78e68d5c 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -31,6 +31,13 @@ logger = init_logger(__name__) +# TODO: move registration of custom op to aiter_ops.py +# `from vllm._aiter_ops import rocm_aiter_ops` +# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` +# for envs checks which does not require @cache anymore. +# triton kernel is torch compile compatible. +# does not require direct registeration. +# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`. 
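The TODO above describes the same consolidation applied throughout this diff: per-module `@cache` env helpers and `direct_register_custom_op` registrations are replaced by checks and wrappers on the `rocm_aiter_ops` facade. Since `vllm/_aiter_ops.py` itself is not part of this hunk, the sketch below is only an illustration of that pattern, reconstructed from helpers the diff deletes (for example the old `is_rocm_aiter_rmsnorm_enabled()` in layernorm.py and `check_aiter_fp8_linear_support()` in fp8_utils.py); the class layout and the `is_enabled()` helper are assumptions, not the actual module.

```python
# Illustrative sketch only: vllm/_aiter_ops.py is not shown in this diff, so the
# exact layout is assumed. The checks mirror the module-level helpers deleted
# here (e.g. is_rocm_aiter_rmsnorm_enabled, check_aiter_fp8_linear_support).
import vllm.envs as envs
from vllm.platforms import current_platform


class rocm_aiter_ops:
    @classmethod
    def is_enabled(cls) -> bool:
        # Hypothetical base gate: AITER paths only apply on ROCm with the
        # master switch turned on.
        return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER

    @classmethod
    def is_rmsnorm_enabled(cls) -> bool:
        # Mirrors the deleted is_rocm_aiter_rmsnorm_enabled() in layernorm.py.
        # Reading the envs directly keeps the check cheap without functools.cache.
        return cls.is_enabled() and envs.VLLM_ROCM_USE_AITER_RMSNORM

    @classmethod
    def is_linear_enabled(cls) -> bool:
        # Mirrors the inline env check deleted from AiterScaledMMLinearKernel.
        return cls.is_enabled() and envs.VLLM_ROCM_USE_AITER_LINEAR
```

Because these checks are cheap lookups rather than cached module-level results, call sites in this diff can evaluate them at construction time, as the updated `RMSNorm.__init__` does before passing `use_aiter` into `dispatch_rocm_rmsnorm_func`.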
@cache def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: return ( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7fecda2166ef..63726c07b7d1 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -68,78 +69,6 @@ def cutlass_scaled_mm( ) -def rocm_aiter_gemm_w8a8_blockscale_impl( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - def is_aiter_triton_kernel_tuned(n, k): - return (n, k) in [ - (1024, 8192), - (2112, 7168), - (3072, 1536), - (32768, 8192), - (4096, 7168), - (4608, 7168), - (512, 7168), - (7168, 2048), - (7168, 256), - (8192, 1024), - (8192, 32768), - ] - - n, k = weight.shape - if input_scale is not None: - q_input = input_2d - elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k): - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale - - # MI350 case uses triton kernel - q_input, input_scale = per_token_group_quant_fp8( - input_2d, - group_size, - column_major_scales=False, - use_ue8m0=False, - ) - else: - # MI300 uses tuned AITER ASM/C++ kernel - import aiter as rocm_aiter - from aiter import gemm_a8w8_blockscale, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) - q_input, input_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8 - ) - - return gemm_a8w8_blockscale( - q_input, weight, input_scale, weight_scale, dtype=output_dtype - ) - - -def rocm_aiter_gemm_w8a8_blockscale_fake( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = input_2d.shape[0] - n = weight.shape[0] - return torch.empty(m, n, dtype=output_dtype, device=input_2d.device) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8_blockscale", - op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - ) - - # TODO we should be able to change the type of block_size to GroupShape # after we resolve GroupShape compilation issue # https://github.com/vllm-project/vllm/issues/25270 @@ -385,14 +314,40 @@ def _run_aiter( input_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert self.act_quant_group_shape == GroupShape(1, 128) - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - input_2d, - weight, - input_scale, - weight_scale, - self.act_quant_group_shape.col, - input_2d.dtype, - ) + + n, k = weight.shape + if input_scale is not None: + q_input = input_2d + + # MI350 case uses triton kernel + if ( + not current_platform.is_fp8_fnuz() + and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) + ): + q_input, input_scale = per_token_group_quant_fp8( + input_2d, + self.act_quant_group_shape.col, + column_major_scales=False, + use_ue8m0=False, + ) + return rocm_aiter_ops.triton_gemm_a8w8_blockscale( + q_input, + weight, + input_scale, + weight_scale, + 
input_2d.dtype, + ) + + # MI300 uses tuned AITER ASM/C++ kernel + else: + q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) + return rocm_aiter_ops.gemm_w8a8_blockscale( + q_input, + weight, + input_scale, + weight_scale, + input_2d.dtype, + ) def _run_triton( self, @@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace( s_old.copy_(s_requant) -def check_aiter_fp8_linear_support() -> bool: - """AITER is only supported on ROCm for MI3XX""" - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - ) - - def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 380431e86435..7fe902807a74 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -472,7 +472,7 @@ def apply( # Example: # When the number of token is 1, per-token scale is [[1]] # When per-tensor scale is [1] or (). - per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 + per_tensor_weights = weight_scale.numel() == 1 per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 # TODO(luka) do this dispatch during init (after ScaledMM refactor) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 91276320df4d..2ef54e75df44 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,13 +4,10 @@ import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch -from .rocm_aiter_rope_ops import ( - is_rocm_triton_rotary_embedding_enabled, - rocm_aiter_rotary_emb, -) @CustomOp.register("rotary_embedding") @@ -48,8 +45,8 @@ def __init__( cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_triton_rotary_embedding_enabled = ( - is_rocm_triton_rotary_embedding_enabled() + self.is_rocm_triton_rotary_embed_enabled = ( + rocm_aiter_ops.is_triton_rotary_embed_enabled() ) def _compute_inv_freq(self, base: float) -> torch.Tensor: @@ -169,9 +166,9 @@ def forward_hip( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if self.is_rocm_triton_rotary_embedding_enabled: + if self.is_rocm_triton_rotary_embed_enabled: self._match_cos_sin_cache_dtype(query) - rocm_aiter_rotary_emb( + rocm_aiter_ops.triton_rotary_embed( positions, query, key, diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index d9134f05fddf..e72834e473c1 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -146,6 +146,15 @@ def forward_native( key = key_rot return query, key + def forward_hip( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.forward_native(positions, query, key, offsets) + def forward_cuda( 
self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py deleted file mode 100644 index a01d14f7b3a1..000000000000 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.envs as envs -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_triton_rotary_embedding_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_TRITON_ROPE - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_impl( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - import aiter.ops.triton.rope as ops - - ops.rope_cached_thd_positions_2c_fwd_inplace( - query, - key, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_fake( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - pass - - -if is_rocm_triton_rotary_embedding_enabled(): - direct_register_custom_op( - op_name="rocm_aiter_rotary_emb_with_key_forward_triton", - op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, - mutates_args=["key", "query"], - fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake, - dispatch_key=current_platform.dispatch_key, - ) - - -def rocm_aiter_rotary_emb( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - cos_sin_cache: torch.Tensor, - head_size: int, - rotary_dim: int, - is_neox_style: bool, -): - num_tokens = positions.numel() - cos, sin = cos_sin_cache.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - rotate_style = 0 if is_neox_style else 1 - - query = query.view(num_tokens, -1, head_size) - key = key.view(num_tokens, -1, head_size) - query_ = query[..., :rotary_dim] - key_ = key[..., :rotary_dim] - positions = positions.view(*query.shape[:1]) - torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton( - positions, - sin, - cos, - query_, - key_, - rotate_style, - False, - ) - query = query.view(query_shape) - key = key.view(key_shape) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 63eaf63cc3c4..38189e17f7d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -33,6 +33,7 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton @@ -50,10 +51,6 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, -) from 
vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -294,10 +291,8 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - if ( - config.n_shared_experts is None - or is_rocm_aiter_fusion_shared_expert_enabled() - ): + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled: self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -330,14 +325,14 @@ def __init__( # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 - if not is_rocm_aiter_moe_enabled() + if not self.is_rocm_aiter_moe_enabled else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, n_shared_experts=config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() else None, ) @@ -371,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - if not is_rocm_aiter_moe_enabled(): + if not self.is_rocm_aiter_moe_enabled: final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None @@ -1428,6 +1423,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -1456,7 +1454,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: num_experts=self.config.n_routed_experts + ( self.config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_moe_shared_expert_enabled else 0 ), num_redundant_experts=self.num_redundant_experts, @@ -1472,9 +1470,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if spec_layer is not None: continue # skip spec decode layers for main model - is_fuse_shared_experts_layer = ( - is_rocm_aiter_fusion_shared_expert_enabled() - and ("mlp.shared_experts" in name) + is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( + "mlp.shared_experts" in name ) for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1abd6300036d..e6536a02a73d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: + from vllm._aiter_ops import rocm_aiter_ops + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and 
envs.VLLM_ROCM_USE_AITER) + and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None ) @@ -202,12 +204,15 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: from importlib.util import find_spec + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + if rocm_aiter_ops.is_mha_enabled(): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. return _Backend.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: @@ -228,19 +233,23 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") - if use_mla: - from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( - is_aiter_mla_enabled, + + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) + if use_mla: if selected_backend is None: selected_backend = ( _Backend.ROCM_AITER_MLA - if is_aiter_mla_enabled() or block_size == 1 + if rocm_aiter_ops.is_mla_enabled() or block_size == 1 else _Backend.TRITON_MLA ) @@ -265,12 +274,12 @@ def get_attn_backend_cls( logger.info("Using FlexAttention backend.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + rocm_aiter_ops.is_mha_enabled() ) or selected_backend == _Backend.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend.") return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + rocm_aiter_ops.is_triton_unified_attn_enabled() ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend.") return ( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 6c8145b6847d..102c4a288d13 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -198,6 +198,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionLayer, @@ -271,28 +272,15 @@ class QueryLenSupport(Enum): flashinfer_available = False -def is_rocm_aiter_fp8bmm_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_FP8BMM - and envs.VLLM_ROCM_USE_AITER - ) - - -if is_rocm_aiter_fp8bmm_enabled(): - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501 - ) - - def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn - ): - DTYPE_MAX = torch.finfo(dtype).max - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) - scale = DTYPE_MAX / amax - x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) - return x_scl_sat.to(dtype).contiguous(), 
scale.float().reciprocal() +def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() logger = init_logger(__name__) @@ -1114,6 +1102,7 @@ def __init__( self.kv_b_proj = kv_b_proj self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1163,7 +1152,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1192,7 +1181,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1201,7 +1190,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1213,10 +1202,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm( + x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) # Convert from (B, N, V) to (B, N * V) @@ -1576,7 +1564,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1605,7 +1593,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1614,7 +1602,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1963,7 +1951,6 @@ def forward( # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) @@ -1971,9 +1958,9 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, 
B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm( + decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, self.W_K, self.W_K_scale, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 71eac84b6f06..87733c6af05c 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,9 +6,8 @@ import torch -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionLayer -from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import ( @@ -22,10 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA - - class AiterMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -288,7 +283,7 @@ def _forward_decode( # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd( + rocm_aiter_ops.mla_decode_fwd( q, kv_buffer, o,
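The MLA changes above keep the per-batched-tensor fp8 quantization helper as plain PyTorch and route the batched GEMMs through `rocm_aiter_ops.triton_fp8_bmm`. Since the helper's full body appears in this diff, here is a small standalone usage sketch; the shape comes from the `# 16 512 128` annotation on `W_K` in the hunk, and the driver code around the helper is illustrative only, not part of the change.

```python
# Usage sketch for dynamic_per_batched_tensor_quant as defined in this diff.
# The helper returns the fp8 tensor plus the *reciprocal* scale, so
# dequantization is a single multiply.
import torch


def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
    DTYPE_MAX = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = DTYPE_MAX / amax
    x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


if __name__ == "__main__":
    # W_K-like weight, shape annotated "# 16 512 128" in the hunk above.
    w = torch.randn(16, 512, 128, dtype=torch.bfloat16)
    w_fp8, w_scale = dynamic_per_batched_tensor_quant(w)
    # Dequantize with the reciprocal scale and check the round-trip error.
    w_deq = w_fp8.to(torch.float32) * w_scale
    print("max abs error:", (w_deq - w.float()).abs().max().item())
```

This is why the call sites keep `self.W_K_scale` / `self.W_V_scale` next to the quantized weights and pass them into `rocm_aiter_ops.triton_fp8_bmm` together with the fp8 tensors.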