diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index ab0cf2051f87..1d707d56daba 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -283,11 +283,18 @@ def use_trtllm_attention( if force_use_trtllm is None: # Environment variable not set - use auto-detection - use_trtllm = ( - num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto" - ) - if use_trtllm: - logger.warning_once("Using TRTLLM attention (auto-detected).") + if is_prefill: + # Prefill auto-detection + use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto" + if use_trtllm: + logger.warning_once("Using TRTLLM prefill attention (auto-detected).") + else: + # Decode auto-detection + use_trtllm = ( + num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto" + ) + if use_trtllm: + logger.warning_once("Using TRTLLM decode attention (auto-detected).") return use_trtllm # Environment variable is set to 1 - respect it