diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 0d7d4b694f07..2e31b7bad747 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -148,6 +148,31 @@ def has_nvidia_artifactory() -> bool:
     return False
 
 
+@functools.cache
+def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
+    """Cache result which only depends on the environment"""
+    # This is a lambda, call it once
+    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
+
+    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
+    if not (current_platform.is_device_capability(100)
+            and has_nvidia_artifactory()):
+        return False, env_value
+
+    if env_value is not None:
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
+        # Environment variable is set - respect it
+        # Making the conditional check for zero because
+        # the path is automatically enabled if the batch size condition
+        # is satisfied.
+        use_trtllm = (env_value == "1")
+        if use_trtllm:
+            logger.info_once("Using TRTLLM attention.")
+        return use_trtllm, env_value
+
+    return True, None
+
+
 def use_trtllm_attention(
     num_tokens: int,
     max_seq_len: int,
@@ -157,9 +182,8 @@ def use_trtllm_attention(
     attn_head_size: Optional[int],
     has_sinks: bool = False,
 ) -> bool:
-    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
+    use_trtllm, env_value = supports_trtllm_attention()
+    if not use_trtllm:
         return False
 
     # Check if the dimensions are supported by TRTLLM decode attention
@@ -174,18 +198,7 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-    if env_value is not None:
-        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm
-    else:
+    if env_value is None:
         # Environment variable not set - use auto-detection
         use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                       and kv_cache_dtype == "auto")
@@ -193,6 +206,9 @@ def use_trtllm_attention(
             logger.warning_once("Using TRTLLM attention (auto-detected).")
         return use_trtllm
 
+    # Environment variable is set to 1 - respect it
+    return True
+
 
 if has_flashinfer():
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 1c7d08798964..5e6bc331835b 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -248,19 +248,23 @@ def use_cascade_attention(
 
 
 @functools.lru_cache
 def get_kv_cache_layout():
+    # Format specified by the code.
     global _KV_CACHE_LAYOUT_OVERRIDE
-    # Override with format specified by the user.
+
+    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
+        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
+        logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
+                         "Setting KV cache layout to %s.", cache_layout)
+        return cache_layout
+
+    # Format specified by the user.
     cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    # When neither the user nor the override specified a layout, get default
     if cache_layout is None:
-        if envs.VLLM_USE_TRTLLM_ATTENTION:
-            cache_layout = "HND"
-        else:
-            cache_layout = get_kv_connector_cache_layout()
+        cache_layout = get_kv_connector_cache_layout()
     else:
         logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                          "detected. Setting KV cache layout to %s.", cache_layout)
-    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
-        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
 
     return cache_layout
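
Reviewer note: the flashinfer.py change splits the decision into an environment-dependent gate, cached once per process with functools.cache, and a cheap per-call heuristic that stays uncached. Below is a minimal, self-contained sketch of that pattern. The thresholds (256 tokens, 131072 max_seq_len) and the "0"/"1" semantics come from the diff; every name here (supports_feature, use_feature, _platform_ok, USE_FEATURE) is a hypothetical stand-in, not vLLM's actual API.

import functools
import os
from typing import Optional


def _platform_ok() -> bool:
    # Hypothetical stand-in for current_platform.is_device_capability(100)
    # and has_nvidia_artifactory(); always True so the sketch runs anywhere.
    return True


@functools.cache
def supports_feature() -> tuple[bool, Optional[str]]:
    # Runs once per process: the hardware and the environment variable
    # cannot change mid-run, so the result is safe to cache.
    env_value = os.environ.get("USE_FEATURE")  # stand-in for VLLM_USE_TRTLLM_ATTENTION
    if not _platform_ok():
        return False, env_value
    if env_value is not None:
        # "0" force-disables; "1" force-enables (still subject to per-call checks).
        return env_value == "1", env_value
    return True, None  # supported, and the user expressed no preference


def use_feature(num_tokens: int, max_seq_len: int) -> bool:
    # Runs per call: combines the cached gate with workload-dependent checks.
    supported, env_value = supports_feature()
    if not supported:
        return False
    if env_value is None:
        # No user preference: auto-detect from the workload shape,
        # using the same thresholds as the diff.
        return num_tokens <= 256 and max_seq_len < 131072
    # env_value == "1": the user opted in explicitly.
    return True


print(use_feature(num_tokens=128, max_seq_len=4096))   # True under defaults
print(use_feature(num_tokens=1024, max_seq_len=4096))  # False: auto-detection declines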
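
Reviewer note: the backends/utils.py change is an ordering fix. The in-code override (_KV_CACHE_LAYOUT_OVERRIDE) now short-circuits before the user's environment variable is even read, and the implicit "HND when TRTLLM attention is enabled" special case is dropped in favor of the connector default. A sketch of the resulting three-level precedence, again with hypothetical stand-ins (_OVERRIDE, KV_CACHE_LAYOUT, _connector_default) rather than the real vLLM helpers:

import functools
import os
from typing import Optional

# Stand-in for _KV_CACHE_LAYOUT_OVERRIDE, set programmatically (e.g. by tests).
_OVERRIDE: Optional[str] = None


def _connector_default() -> str:
    # Hypothetical stand-in for get_kv_connector_cache_layout().
    return "NHD"


@functools.lru_cache
def get_layout() -> str:
    # 1. Code-level override wins unconditionally.
    if _OVERRIDE is not None:
        return _OVERRIDE
    # 2. Otherwise honor the user's environment variable, if set.
    env_layout = os.environ.get("KV_CACHE_LAYOUT")  # stand-in for VLLM_KV_CACHE_LAYOUT
    if env_layout is not None:
        return env_layout
    # 3. Neither set: fall back to the connector-derived default.
    return _connector_default()

As in the patched get_kv_cache_layout, the result is lru_cached, so the override only takes effect if it is installed before the first lookup.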