Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake(
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor:
if output_dtype is not None:
return torch.empty_like(hidden_states, dtype=output_dtype)
Expand Down Expand Up @@ -1700,7 +1704,7 @@ def gemm_a8wfp4(
)

@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
def triton_fp4_gemm_dynamic_quant(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def __init__(
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

from collections.abc import Callable
from fractions import Fraction
from functools import cache, partial
from functools import partial
from typing import Any

import torch
import torch.nn.functional as F

from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
Expand Down Expand Up @@ -37,22 +36,6 @@
logger = init_logger(__name__)


# TODO: move registration of custom op to aiter_ops.py
# `from vllm._aiter_ops import rocm_aiter_ops`
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for envs checks which does not require @cache anymore.
# triton kernel is torch compile compatible.
# does not require direct registration.
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
and envs.VLLM_ROCM_USE_AITER
)


try:
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import (
Expand All @@ -63,7 +46,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:

from vllm.utils.torch_utils import direct_register_custom_op

if is_rocm_aiter_fp4_asm_gemm_enabled():
if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip

def gemm_with_dynamic_quant(
Expand Down Expand Up @@ -233,7 +216,9 @@ def __init__(
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
)

self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
self.rocm_use_aiter_fp4_asm_gemm = (
rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()
)

if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
# Currently need these kernels if not emulating
Expand Down
9 changes: 5 additions & 4 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ def cp_mha_gather_cache(
total_tokens: int,
):
assert kv_cache_layout in ["NHD", "SHUFFLE"], (
"kv_cache_layout only support NHD, SHUFFLE"
"kv_cache_layout only supports NHD, SHUFFLE"
)
head_dim = key.shape[2]
x = 16 // key_cache.element_size()
# assert dequant is True, "Currently, we only support "\
# "gather cache with dequant"
# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
# For k cache layout: [num_blocks, page_size, num_heads, head_dim]
assert head_dim == key_cache.shape[3], (
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
Expand Down Expand Up @@ -832,7 +832,7 @@ def __init__(

if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for FlashAttentionImpl"
"Encoder self-attention is not implemented for AiterFlashAttentionImpl"
)

def extend_for_sliding_window(
Expand Down Expand Up @@ -1047,7 +1047,8 @@ def forward(

if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for FlashAttentionImpl"
"fused output quantization is not yet supported "
"for AiterFlashAttentionImpl"
)

if attn_metadata is None:
Expand Down
Loading