diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 5d0343ffd607..258f173da3a4 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -16,9 +16,10 @@ from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton if current_platform.is_rocm(): - from vllm.platforms.rocm import _ON_GFX942 + from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950 else: _ON_GFX942 = False + _ON_GFX950 = False @triton.jit @@ -385,7 +386,7 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module = paged_mqa_logits_module() if aiter_paged_mqa_logits_module is not None: - if _ON_GFX942: + if _ON_GFX942 or _ON_GFX950: deepgemm_fp8_paged_mqa_logits = ( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits )