Commit 2dec7c1

[Bugfix][CUDA] fixes CUDA FP8 kv cache dtype supported (#21420)
Signed-off-by: elvischenv <[email protected]>
1 parent 08d2bd7 commit 2dec7c1

1 file changed (+13 −13)

vllm/platforms/cuda.py

Lines changed: 13 additions & 13 deletions
@@ -456,6 +456,19 @@ def stateless_init_device_torch_dist_pg(
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -583,19 +596,6 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
                 " not found. Assuming no NVLink available.")
             return False
 
-    @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
-        fp8_attention = kv_cache_dtype.startswith("fp8")
-        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
-                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-        supported = False
-        if cls.is_device_capability(100):
-            supported = True
-        elif fp8_attention and will_use_fa:
-            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
-            supported = flash_attn_supports_fp8()
-        return supported
-
 
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.
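
For readers who want to see what the relocated hook reports on their own hardware, here is a small illustrative probe. It assumes only the public `from vllm.platforms import current_platform` import and the documented `--kv-cache-dtype` string values (`fp8`, `fp8_e4m3`, `fp8_e5m2`); reading the first hunk as moving the method onto the shared CUDA base class (next to `device_count`) is my interpretation of the diff, and the snippet is a sketch around this commit's method, not part of the commit itself.

    # Illustrative probe (not part of this commit): ask the active vLLM
    # platform whether an FP8 KV-cache dtype is supported on this GPU.
    from vllm.platforms import current_platform

    for kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        # Per the diff above, is_kv_cache_dtype_supported() now sits beside
        # device_count() rather than in the NVML-specific section of cuda.py.
        ok = current_platform.is_kv_cache_dtype_supported(kv_cache_dtype)
        print(f"{kv_cache_dtype}: {'supported' if ok else 'not supported'}")

As the moved code shows, the check passes unconditionally when cls.is_device_capability(100) is true (compute capability 10.0 devices), and otherwise, for fp8* dtypes, it defers to flash_attn_supports_fp8() when the FLASH_ATTN_VLLM_V1 backend would be used.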
