160 changes: 160 additions & 0 deletions vllm/_aiter_ops.py
@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale,
num_local_tokens=num_local_tokens,
dtype=output_dtype,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=bias1,
bias2=bias2,
)


@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
pass


def _rocm_aiter_fused_topk_impl(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.fused_moe import fused_topk

# fused_topk returns (topk_weights, topk_indices)
return fused_topk(x, router_logits, top_k, gate_up)


def _rocm_aiter_fused_topk_fake(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Return correctly shaped tensors so torch.compile can trace through the
    # op; shapes follow the (num_tokens, top_k) fused_topk convention, and
    # the float32/int32 dtypes are assumed.
    num_tokens = router_logits.shape[0]
    topk_weights = torch.empty(
        num_tokens, top_k, dtype=torch.float32, device=x.device
    )
    topk_indices = torch.empty(
        num_tokens, top_k, dtype=torch.int32, device=x.device
    )
    return topk_weights, topk_indices
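
A minimal shape sketch for the new wrapper (illustrative only; requires ROCm
with aiter installed, and the hidden size, device, and dtypes below are
assumptions, not taken from this PR). The (num_tokens, top_k) output shapes
follow the convention assumed by the fake impl above:

    import torch

    num_tokens, hidden, num_experts, top_k = 8, 4096, 64, 2
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(num_tokens, num_experts, device="cuda")

    topk_weights, topk_indices = _rocm_aiter_fused_topk_impl(
        x, router_logits, top_k, gate_up=False
    )
    assert topk_weights.shape == (num_tokens, top_k)
    assert topk_indices.shape == (num_tokens, top_k)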


# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None

@@ -994,6 +1024,70 @@ def refresh_env_variables(cls):
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

@staticmethod
def get_aiter_activation_type(activation_str: str):
"""
Given an activation type as a string, returns the corresponding aiter ActivationType enum.
Supported activation types: "no", "none", "silu", "gelu", "swiglu".
Returns None if the mapping fails.

Args:
activation_str (str): Activation type as string.

Returns:
Aiter ActivationType enum value, or None if not found.
"""
# Import only locally, since aiter may not always be available.
try:
from aiter import ActivationType
except ImportError:
return None

if not isinstance(activation_str, str):
return None

name = activation_str.strip().lower()
mapping = {
"none": ActivationType.No,
"no": ActivationType.No,
"silu": ActivationType.Silu,
"gelu": ActivationType.Gelu,
"swiglu": ActivationType.Swiglu,
}
return mapping.get(name)

@staticmethod
def get_aiter_quant_type(quant_type_str: str):
"""
Given a quantization type as a string, returns the corresponding aiter QuantType enum.
Supported quantization types: "no", "per_tensor", "per_token", "per_1x32", "per_1x128", "per_128x128".
Returns None if the mapping fails.

Args:
quant_type_str (str): Quantization type as string.

Returns:
Aiter QuantType enum value, or None if not found.
"""
try:
from aiter import QuantType
except ImportError:
return None

if not isinstance(quant_type_str, str):
return None

name = quant_type_str.strip().lower()
mapping = {
"no": QuantType.No,
"per_tensor": QuantType.per_Tensor,
"per_token": QuantType.per_Token,
"per_1x32": QuantType.per_1x32,
"per_1x128": QuantType.per_1x128,
"per_128x128": QuantType.per_128x128,
}
return mapping.get(name)
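
A usage sketch for the two mapping helpers above (illustrative, not part of
this diff; assumes the enclosing class is rocm_aiter_ops). Both return None
rather than raising when aiter is missing or the name is unknown, so callers
should validate the result:

    act = rocm_aiter_ops.get_aiter_activation_type("silu")
    quant = rocm_aiter_ops.get_aiter_quant_type("per_token")
    if act is None or quant is None:
        raise ValueError("unsupported activation or quantization type")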

@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
@@ -1127,6 +1221,14 @@ def register_ops_once() -> None:
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_fused_topk",
op_func=_rocm_aiter_fused_topk_impl,
mutates_args=[],
fake_impl=_rocm_aiter_fused_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
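
A sanity sketch of what this registration enables (illustrative; assumes ROCm
with aiter installed): the fake impl supplies output shapes, so torch.compile
can trace through the op without launching the kernel at trace time:

    import torch

    @torch.compile
    def route(x: torch.Tensor, logits: torch.Tensor):
        return torch.ops.vllm.rocm_aiter_fused_topk(x, logits, 2, False)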

direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl,
@@ -1360,6 +1462,10 @@ def fused_moe(
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
@@ -1377,6 +1483,10 @@
a2_scale,
num_local_tokens,
output_dtype,
hidden_pad,
intermediate_pad,
bias1,
bias2,
)

@staticmethod
@@ -1481,6 +1591,15 @@ def grouped_topk(
routed_scaling_factor,
)

@staticmethod
def fused_topk(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_fused_topk(
            x, router_logits, top_k, gate_up
        )
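
A call-site sketch for the public wrapper (illustrative; assumes ROCm with
aiter installed and tensors shaped as in the sketch near the impl above):

    topk_weights, topk_indices = rocm_aiter_ops.fused_topk(
        x, router_logits, top_k=2, gate_up=False
    )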

@staticmethod
def mla_decode_fwd(
q: torch.Tensor,
@@ -1701,6 +1820,47 @@ def shuffle_weight(

return shuffle_weight(tensor, layout=layout)

@staticmethod
def shuffle_weight_a16w4(
tensor: "torch.Tensor",
nLane: int,
gate_up: bool,
) -> "torch.Tensor":
"""
Shuffles the weight tensor into (A16W4) layout for AITER kernels.

Args:
tensor: The input weight tensor to be shuffled.
layout: The block layout to use, defaults to (16, 4).

Returns:
torch.Tensor: The shuffled tensor.
"""
from aiter.ops.shuffle import shuffle_weight_a16w4

return shuffle_weight_a16w4(tensor, nLane, gate_up)

@staticmethod
def shuffle_scale_a16w4(
tensor: "torch.Tensor",
num_experts: int,
gate_up: bool,
) -> "torch.Tensor":
"""
Shuffles the scale tensor into the A16W4 layout for AITER kernels.

Args:
tensor: The input scale tensor to be shuffled.
num_experts: Number of experts, needed for reshaping logic.
gate_up: Whether the scale is for w13 (True) or w2 (False).

Returns:
torch.Tensor: The shuffled scale tensor.
"""
from aiter.ops.shuffle import shuffle_scale_a16w4

return shuffle_scale_a16w4(tensor, num_experts, gate_up)
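
A weight-preparation sketch for the two A16W4 helpers above (illustrative;
the nLane value, expert count, and tensor names are assumptions, not taken
from this PR):

    # w13 packs the gate/up projections; the down projection (w2) uses
    # gate_up=False.
    w13_shuffled = rocm_aiter_ops.shuffle_weight_a16w4(
        w13, nLane=16, gate_up=True
    )
    w13_scale_shuffled = rocm_aiter_ops.shuffle_scale_a16w4(
        w13_scale, num_experts=64, gate_up=True
    )
    w2_shuffled = rocm_aiter_ops.shuffle_weight_a16w4(
        w2, nLane=16, gate_up=False
    )
    w2_scale_shuffled = rocm_aiter_ops.shuffle_scale_a16w4(
        w2_scale, num_experts=64, gate_up=False
    )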

@staticmethod
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)