46 changes: 31 additions & 15 deletions vllm/utils/flashinfer.py
@@ -148,6 +148,31 @@ def has_nvidia_artifactory() -> bool:
return False


@functools.cache
def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
"""Cache result which only depends on the environment"""
# This is a lambda, call it once
env_value = envs.VLLM_USE_TRTLLM_ATTENTION

# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
if not (current_platform.is_device_capability(100)
and has_nvidia_artifactory()):
return False, env_value

if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
use_trtllm = (env_value == "1")
if use_trtllm:
logger.info_once("Using TRTLLM attention.")
return use_trtllm, env_value

return True, None


def use_trtllm_attention(
num_tokens: int,
max_seq_len: int,
@@ -157,9 +182,8 @@ def use_trtllm_attention(
attn_head_size: Optional[int],
has_sinks: bool = False,
) -> bool:
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
if not (current_platform.is_device_capability(100)
and has_nvidia_artifactory()):
use_trtllm, env_value = supports_trtllm_attention()
if not use_trtllm:
return False

# Check if the dimensions are supported by TRTLLM decode attention
@@ -174,25 +198,17 @@ def use_trtllm_attention(
"Using TRTLLM attention (required for attention sinks).")
return True

env_value = envs.VLLM_USE_TRTLLM_ATTENTION
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
use_trtllm = (env_value == "1")
if use_trtllm:
logger.info_once("Using TRTLLM attention.")
return use_trtllm
else:
if env_value is None:
# Environment variable not set - use auto-detection
use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
and kv_cache_dtype == "auto")
if use_trtllm:
logger.warning_once("Using TRTLLM attention (auto-detected).")
return use_trtllm

# Environment variable is set to 1 - respect it
return True


if has_flashinfer():

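The change above splits the decision into a cached, environment-only gate (supports_trtllm_attention) and a cheap per-call heuristic (use_trtllm_attention). The sketch below is illustrative only, not the vLLM implementation: _env_value and _device_ok are hypothetical stand-ins for envs.VLLM_USE_TRTLLM_ATTENTION and the SM100/artifactory checks, and the heuristic keeps only the thresholds visible in the diff.

# Minimal sketch of the cached-gate + per-call-heuristic split (illustrative;
# the stubs below replace vLLM internals).
import functools
import os
from typing import Optional


def _env_value() -> Optional[str]:
    # Stand-in for envs.VLLM_USE_TRTLLM_ATTENTION (evaluated lazily in vLLM).
    return os.environ.get("VLLM_USE_TRTLLM_ATTENTION")


def _device_ok() -> bool:
    # Stand-in for the SM100 + NVIDIA artifactory checks.
    return True


@functools.cache
def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
    """Environment-only gate: computed once, then cached."""
    env_value = _env_value()
    if not _device_ok():
        return False, env_value
    if env_value is not None:
        return env_value == "1", env_value
    return True, None


def use_trtllm_attention(num_tokens: int, max_seq_len: int,
                         kv_cache_dtype: str) -> bool:
    """Per-call heuristic layered on top of the cached gate."""
    ok, env_value = supports_trtllm_attention()
    if not ok:
        return False
    if env_value is None:
        # Auto-detect: small batches and short sequences prefer TRTLLM.
        return (num_tokens <= 256 and max_seq_len < 131072
                and kv_cache_dtype == "auto")
    return True  # env var explicitly set to "1"

Because the gate is memoized with functools.cache, the environment variable and device checks run once per process; later calls only pay for the integer comparisons.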
18 changes: 11 additions & 7 deletions vllm/v1/attention/backends/utils.py
@@ -248,19 +248,23 @@ def use_cascade_attention(

@functools.lru_cache
def get_kv_cache_layout():
# Format specified by the code.
global _KV_CACHE_LAYOUT_OVERRIDE
# Override with format specified by the user.

if _KV_CACHE_LAYOUT_OVERRIDE is not None:
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
"Setting KV cache layout to %s.", cache_layout)
return cache_layout

# Format specified by the user.
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
# When neither the user nor the override specified a layout, get default
if cache_layout is None:
if envs.VLLM_USE_TRTLLM_ATTENTION:
cache_layout = "HND"
else:
cache_layout = get_kv_connector_cache_layout()
cache_layout = get_kv_connector_cache_layout()
else:
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
"detected. Setting KV cache layout to %s.", cache_layout)
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
return cache_layout


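The reworked get_kv_cache_layout resolves the layout in a fixed order: the code-level _KV_CACHE_LAYOUT_OVERRIDE first, then the user's VLLM_KV_CACHE_LAYOUT environment variable, then the KV-connector default. A self-contained sketch of that precedence follows; _connector_default is a hypothetical stand-in for get_kv_connector_cache_layout(), and the logging is omitted.

# Sketch of the layout-resolution precedence (illustrative; the connector
# lookup is stubbed out, and the override is plain module state here).
import functools
import os
from typing import Optional

_KV_CACHE_LAYOUT_OVERRIDE: Optional[str] = None


def _connector_default() -> str:
    # Stand-in for get_kv_connector_cache_layout().
    return "NHD"


@functools.lru_cache
def get_kv_cache_layout() -> str:
    # 1. Code-level override wins unconditionally.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        return _KV_CACHE_LAYOUT_OVERRIDE
    # 2. Otherwise honor the user's environment variable, if set.
    cache_layout = os.environ.get("VLLM_KV_CACHE_LAYOUT")
    if cache_layout is None:
        # 3. Fall back to the connector's preferred layout.
        cache_layout = _connector_default()
    return cache_layout

As in the diff, the result is memoized by functools.lru_cache, so the precedence is evaluated once and then reused.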