Merged
Changes from 3 commits
45 changes: 39 additions & 6 deletions vllm/model_executor/layers/layernorm.py
@@ -9,6 +9,7 @@
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -17,6 +18,7 @@ def is_rocm_aiter_rmsnorm_enabled() -> bool:
and envs.VLLM_ROCM_USE_AITER


# Non-AITER version
def rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
@@ -30,6 +32,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
return out


# Non-AITER version
def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
@@ -43,9 +46,9 @@ def fused_add_rms_norm(
return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:

# AITER version
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
@@ -56,7 +59,22 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return rocm_aiter.rms_norm(x, weight, variance_epsilon)


def rocm_aiter_fused_add_rms_norm(
def rocm_aiter_rms_norm_fake(input: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return torch.empty_like(input)


direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)


# AITER version
def rocm_aiter_fused_add_rms_norm_impl(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

@@ -75,14 +93,29 @@ def rocm_aiter_fused_add_rms_norm(
return output, residual_out


def rocm_aiter_fused_add_rms_norm_fake(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)


direct_register_custom_op(
op_name="rocm_aiter_fused_add_rms_norm",
op_func=rocm_aiter_fused_add_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_fused_add_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)


def dispatch_cuda_rmsnorm_func(add_residual: bool):
if add_residual:
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_fused_add_rms_norm
return torch.ops.vllm.rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm

if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm
return torch.ops.vllm.rocm_aiter_rms_norm
return rms_norm


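The hunk above registers the AITER kernels as torch custom ops: direct_register_custom_op (vLLM's helper from vllm.utils) binds the real implementation together with a fake implementation that only returns correctly shaped empty tensors, so torch.compile/meta tracing can reason about the op without running the ROCm kernel, and dispatch_cuda_rmsnorm_func now hands back torch.ops.vllm.rocm_aiter_rms_norm / torch.ops.vllm.rocm_aiter_fused_add_rms_norm. A minimal standalone sketch of the same pattern using PyTorch's torch.library.custom_op (PyTorch 2.4+); the mylib namespace and the reference RMSNorm math are illustrative stand-ins, not vLLM or AITER code:

import torch
from torch.library import custom_op


# Eager implementation: plain reference RMSNorm math stands in for the AITER kernel.
@custom_op("mylib::rms_norm", mutates_args=())
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


# Fake (meta) implementation: returns only shape/dtype information for tracing,
# mirroring rocm_aiter_rms_norm_fake in the diff.
@rms_norm.register_fake
def _(x, weight, eps):
    return torch.empty_like(x)


# The op is now reachable through the dispatcher, analogous to the
# torch.ops.vllm.rocm_aiter_rms_norm call returned by dispatch_cuda_rmsnorm_func.
out = torch.ops.mylib.rms_norm(torch.randn(2, 8), torch.ones(8), 1e-6)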
10 changes: 5 additions & 5 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -146,11 +146,11 @@ def get_hf_processor(
kwargs["fps"] = fps
processor = self.ctx.get_hf_processor(
Qwen2_5OmniProcessor,
image_processor=self.get_image_processor(
min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get("use_fast", True)),
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
**kwargs,
)
if not hasattr(processor, "audio_token"):
10 changes: 5 additions & 5 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -794,11 +794,11 @@ def get_hf_processor(

return self.ctx.get_hf_processor(
Qwen2_5_VLProcessor,
image_processor=self.get_image_processor(
min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get("use_fast", True)),
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
**kwargs,
)

10 changes: 5 additions & 5 deletions vllm/model_executor/models/qwen2_vl.py
@@ -759,11 +759,11 @@ def get_hf_processor(
) -> Qwen2VLProcessor:
return self.ctx.get_hf_processor(
Qwen2VLProcessor,
image_processor=self.get_image_processor(
min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get("use_fast", True)),
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
**kwargs,
)

12 changes: 2 additions & 10 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -201,16 +201,8 @@ def _forward_decode(

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

if self.num_heads == 16:
# AITER MLA decode kernel only supports
# max_seqlen_q=1 when using 16 heads.
max_seqlen_qo = 1
else:
# AITER MLA decode Kernel handles arbitrary
# max_seqlen_q values when using 128 heads.
assert attn_metadata.prefill is not None
max_seqlen_qo = attn_metadata.prefill.max_query_len

# max_seqlen_qo must be 1 except for MTP
max_seqlen_qo = 1
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.decode.qo_indptr, max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr,