[TurboQuant] enable FA3/FA4 for prefill paths #40092
Changes from all commits: 839d499, 7b3b90f, 13d9523, e65bf3a, a74b840, fefea99, dce084e, 4ea1078
```diff
@@ -39,6 +39,7 @@
     MultipleOf,
 )
 from vllm.v1.attention.backends.fa_utils import (
+    get_flash_attn_version,
     is_flash_attn_varlen_func_available,
 )
 from vllm.v1.attention.backends.utils import split_decodes_and_prefills
```
```diff
@@ -271,13 +272,53 @@ def __init__(
         self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
         self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1
 
+        # Detect flash-attn version (FA2/3/4) for prefill paths.
+        self.fa_version = get_flash_attn_version(head_size=head_size)
```
**Comment on lines +275 to +276:** This new FA-version selection path only calls …
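For orientation, a minimal probe of the version-selection call introduced above. This is only a sketch: the import path and the `head_size` keyword are taken from the diff itself, and the values actually returned (2, 3, 4, or `None`) depend on the installed flash-attn build and GPU.

```python
# Hypothetical probe, not part of this PR: inspect which flash-attn
# version vLLM would select for a few head sizes on this machine.
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version

for head_size in (64, 128, 256):
    version = get_flash_attn_version(head_size=head_size)
    # None means the backend should not pass an explicit fa_version kwarg.
    print(f"head_size={head_size}: fa_version={version}")
```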
```diff
+
+        # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
+        # and benchmarks show no regression vs dynamic in eager mode).
+        vllm_config = get_current_vllm_config()
+        self.max_num_kv_splits = (
+            vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
+        )
+
+    def _flash_attn_varlen(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+    ) -> torch.Tensor:
+        # fa_utils.get_flash_attn_version() returns None on backends that
+        # should not pass an explicit fa_version kwarg.
+        if self.fa_version is None:
+            return flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+        return flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=True,
+            fa_version=self.fa_version,
+        )
+
     def _ensure_on_device(self, layer, device):
         """One-time derivation of TQ buffers (rotation matrix, midpoints).
```
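As a side note on the shape of the new wrapper, the two `flash_attn_varlen_func` calls differ only in the `fa_version` keyword. An equivalent, more compact formulation (a sketch, not part of the diff) builds the shared kwargs once:

```python
# Sketch only: same behavior as the wrapper above, with the duplicated
# call collapsed into a single kwargs dict.
def _flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k,
                       max_seqlen_q, max_seqlen_k):
    kwargs = dict(
        q=q, k=k, v=v,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale, causal=True,
    )
    if self.fa_version is not None:
        # Only pass fa_version when the backend expects an explicit one.
        kwargs["fa_version"] = self.fa_version
    return flash_attn_varlen_func(**kwargs)
```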
```diff
@@ -503,16 +544,14 @@ def _prefill_attention(
         # max_query_len == max_seq_len means no request has prior cached KV.
         # Both are Python ints — no GPU sync.
         if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
-            return flash_attn_varlen_func(
+            return self._flash_attn_varlen(
                 q=query,
                 k=key,
                 v=value,
                 cu_seqlens_q=attn_metadata.query_start_loc,
                 cu_seqlens_k=attn_metadata.query_start_loc,
                 max_seqlen_q=attn_metadata.max_query_len,
                 max_seqlen_k=attn_metadata.max_query_len,
-                softmax_scale=self.scale,
-                causal=True,
             )
 
         # Continuation or no flash_attn: per-request attention.
```
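The "no GPU sync" remark in this hunk refers to a general PyTorch pattern, illustrated by the hedged sketch below (not code from this PR): comparing plain Python ints costs nothing on the device, while reading a value out of a CUDA tensor forces a device-to-host copy and synchronization.

```python
import torch

# Sketch, not from the PR: why keeping lengths as Python ints matters.
max_query_len = 512    # plain Python int, as in attn_metadata
max_seq_len = 512      # plain Python int
no_cached_kv = max_query_len == max_seq_len   # pure host-side compare

if torch.cuda.is_available():
    lens = torch.tensor([max_seq_len], device="cuda")
    # .item() copies device -> host and waits for pending GPU work.
    synced_value = lens.item()
```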
```diff
@@ -552,16 +591,14 @@ def _prefill_attention(
             if _HAS_FLASH_ATTN:
                 _cu_2[1] = q_len
                 cu = _cu_2
-                out = flash_attn_varlen_func(
+                out = self._flash_attn_varlen(
                     q=q_seq,
                     k=k_seq,
                     v=v_seq,
                     cu_seqlens_q=cu,
                     cu_seqlens_k=cu,
                     max_seqlen_q=q_len,
                     max_seqlen_k=q_len,
-                    softmax_scale=self.scale,
-                    causal=True,
                 )
             else:
                 q_t = q_seq.transpose(0, 1).contiguous()
```
```diff
@@ -726,16 +763,14 @@ def _continuation_prefill(
         if _HAS_FLASH_ATTN:
             cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
             cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
-            return flash_attn_varlen_func(
+            return self._flash_attn_varlen(
                 q=query,
                 k=k_full,
                 v=v_full,
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=q_len,
                 max_seqlen_k=seq_len,
-                softmax_scale=self.scale,
-                causal=True,
             )
         else:
             # SDPA fallback: expand KV for GQA, build causal mask
```
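For context on the `else:` branch's comment, here is a generic sketch of an SDPA fallback for GQA with a continuation-style causal mask. It is an illustration only, not the PR's implementation; the tensor layouts (`[heads, length, head_size]`) and the helper name are assumptions.

```python
import torch
import torch.nn.functional as F

def sdpa_gqa_continuation(q, k, v, scale):
    # q: [num_q_heads, q_len, head_size]; k, v: [num_kv_heads, seq_len, head_size]
    # The q_len queries are the *last* positions of the seq_len-long sequence.
    num_q_heads, num_kv_heads = q.shape[0], k.shape[0]
    q_len, seq_len = q.shape[1], k.shape[1]

    # Expand KV heads so each query head has a matching KV head (GQA).
    rep = num_q_heads // num_kv_heads
    k = k.repeat_interleave(rep, dim=0)
    v = v.repeat_interleave(rep, dim=0)

    # Causal mask aligned to the end of the sequence: query i sits at
    # absolute position seq_len - q_len + i and may attend to keys up to it.
    q_pos = torch.arange(seq_len - q_len, seq_len, device=q.device)
    k_pos = torch.arange(seq_len, device=q.device)
    mask = k_pos[None, :] <= q_pos[:, None]   # [q_len, seq_len], True = attend

    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scale)
```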
**Review comment:** The call to `get_flash_attn_version` should include the `requires_alibi` argument. Passing `requires_alibi=alibi_slopes is not None` ensures that the backend correctly falls back to FlashAttention 2 if ALiBi slopes are present, as FA3 and FA4 do not currently support them. This maintains consistency with the version detection logic used in `FlashAttentionImpl`.
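A minimal sketch of what the suggested change might look like, assuming `alibi_slopes` is available in this backend's `__init__` (the diff above does not show where it comes from) and that `get_flash_attn_version` accepts a `requires_alibi` keyword as the reviewer states:

```python
# Hypothetical application of the review suggestion; not part of the PR.
# Assumes `alibi_slopes` is an __init__ argument of this backend.
self.fa_version = get_flash_attn_version(
    head_size=head_size,
    requires_alibi=alibi_slopes is not None,
)
```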