@@ -456,6 +456,19 @@ def stateless_init_device_torch_dist_pg(
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -583,19 +596,6 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
583596 " not found. Assuming no NVLink available." )
584597 return False
585598
586- @classmethod
587- def is_kv_cache_dtype_supported (cls , kv_cache_dtype : str ) -> bool :
588- fp8_attention = kv_cache_dtype .startswith ("fp8" )
589- will_use_fa = (not envs .is_set ("VLLM_ATTENTION_BACKEND" )
590- ) or envs .VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
591- supported = False
592- if cls .is_device_capability (100 ):
593- supported = True
594- elif fp8_attention and will_use_fa :
595- from vllm .attention .utils .fa_utils import flash_attn_supports_fp8
596- supported = flash_attn_supports_fp8 ()
597- return supported
598-
599599
600600# Autodetect either NVML-enabled or non-NVML platform
601601# based on whether NVML is available.
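
For context, a short usage sketch (an illustration assumed by the editor, not part of this diff): the relocated classmethod can be queried through the platform singleton, assuming a CUDA build of vLLM where `vllm.platforms.current_platform` resolves to one of the classes above.

    # Usage sketch (assumed, not from this PR): check which fp8 KV-cache dtypes
    # the current CUDA platform accepts. Only meaningful in a CUDA build of vLLM.
    from vllm.platforms import current_platform

    for kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        print(kv_cache_dtype,
              current_platform.is_kv_cache_dtype_supported(kv_cache_dtype))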