Kernels community fa3 #20796
Merged: Fridge003 merged 31 commits into sgl-project:main from bytedance-iaas:kernels-community-fa3 on Apr 7, 2026 (+1,956 −61).
31 commits (all by rainj-me):

- 1df7ce2 support download sgl-flash-attn from kernels community
- ac55c34 fix the function call error
- eae7f75 move all sgl-kernel call to kernels community
- d0537b5 add fallback to sgl-kernel solution since the arm is not supported
- 8ebdc72 remove the kernel lock file from repo and mve the docker and ci kerne…
- 3a9fcdc Update flash_attention_v3.py
- a3f572a Merge branch 'main' into kernels-community-fa3
- 8e8eeff Merge branch 'main' into kernels-community-fa3
- fe7e49e Merge branch 'main' into kernels-community-fa3
- 4f5c488 Merge branch 'main' into kernels-community-fa3
- 729469d Merge branch 'main' into kernels-community-fa3
- 0498ae9 refactor the fa4 and fa4
- c19fdad fix the unit test for fa3
- cac1f74 Merge branch 'main' into kernels-community-fa3
- 7553c4d address comments
- 0cb1af5 fix ci failure with NoneType has no view
- e8f4d32 Merge branch 'main' into kernels-community-fa3
- 86ae3ab Merge branch 'main' into kernels-community-fa3
- e28a8e1 Merge branch 'main' into kernels-community-fa3
- b0b52cc make the fa3 kernels load from sgl-kernel by default
- 241f205 Merge branch 'main' into kernels-community-fa3
- 88d0075 fix unsuccessfully resolve
- 4796bfc Merge branch 'main' into kernels-community-fa3
- 6fb65b8 refactor hardcode SGLANG_CACHE_DIR env variable
- f622af8 Merge branch 'main' into kernels-community-fa3
- 65b80ca fix unsuccessfully resolve
- 654bd22 Merge branch 'main' into kernels-community-fa3
- a9da0ef Merge branch 'main' into kernels-community-fa3
- 510efea Merge branch 'main' into kernels-community-fa3
- 4eabccd Merge branch 'main' into kernels-community-fa3
- 8abedb7 move the kernels imports to try cache block
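Several commits above ("add fallback to sgl-kernel solution since the arm is not supported", "move the kernels imports to try cache block") describe guarding the kernels-community import and falling back to sgl-kernel. A minimal sketch of that pattern, using placeholder module names since the actual import paths are not visible in this capture:

```python
import importlib


def load_flash_attn(preferred="kernels_community_fa3", fallback="sgl_kernel"):
    """Try the kernels-community module first, fall back to sgl-kernel.

    Both module names here are placeholders for illustration; the real
    import paths used by the PR are not shown in this page capture.
    Returns (module_name, module), or (None, None) if neither imports.
    """
    for name in (preferred, fallback):
        try:
            return name, importlib.import_module(name)
        except ImportError:
            # e.g. the kernels-community wheel is unavailable on this
            # platform (the commit log mentions ARM), so try the next one.
            continue
    return None, None
```

Note that a later commit ("make the fa3 kernels load from sgl-kernel by default") reverses the preference order, so the sketch shows only the guarding pattern, not the final default.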
New file (+286 lines):

```python
from typing import Optional, Union

import torch

from .flash_attention_v3 import flash_attn_varlen_func as fa3_flash_attn_varlen_func
from .flash_attention_v3 import flash_attn_with_kvcache as fa3_flash_attn_with_kvcache
from .flash_attention_v4 import flash_attn_varlen_func as fa4_flash_attn_varlen_func
from .flash_attention_v4 import flash_attn_with_kvcache as fa4_flash_attn_with_kvcache


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
        score_mod [optional]: A callable that takes the attention scores and applies a modification.
        aux_tensors [optional]: Some score_mods will want to read from global aux_tensors. This is
            how we thread them through to the inner kernel.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if ver == 3:
        return fa3_flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            k=k,
            v=v,
            qv=qv,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            cache_batch_idx=cache_batch_idx,
            cache_leftpad=cache_leftpad,
            page_table=page_table,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cu_seqlens_k_new,
            max_seqlen_q=max_seqlen_q,
            rotary_seqlens=rotary_seqlens,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            rotary_interleaved=rotary_interleaved,
            scheduler_metadata=scheduler_metadata,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
            return_softmax_lse=return_softmax_lse,
            sinks=sinks,
        )
    elif ver == 4:
        return fa4_flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            k=k,
            v=v,
            qv=qv,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            cache_batch_idx=cache_batch_idx,
            cache_leftpad=cache_leftpad,
            page_table=page_table,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            rotary_seqlens=rotary_seqlens,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sinks=sinks,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            return_softmax_lse=return_softmax_lse,
        )
    else:
        raise RuntimeError(f"Unknown flash attention version {ver}")


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):
    if ver == 3:
        return fa3_flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            page_table=page_table,
            softmax_scale=softmax_scale,
            causal=causal,
            qv=qv,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
            return_softmax_lse=return_softmax_lse,
            sinks=sinks,
        )
    elif ver == 4:
        return fa4_flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            page_table=page_table,
            softmax_scale=softmax_scale,
            causal=causal,
            softcap=softcap,
            window_size=window_size,
            sinks=sinks,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            return_softmax_lse=return_softmax_lse,
        )
    else:
        raise RuntimeError(f"Unknown flash attention version {ver}")
```
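The bottom-right alignment of the causal mask described in the docstring (query row i may attend to key column j only when j <= i + seqlen_k - seqlen_q) can be reproduced in a few lines of plain Python. This is an illustration only, independent of the actual kernels:

```python
def bottom_right_causal_mask(seqlen_q, seqlen_k):
    """Build the bottom-right-aligned causal mask (1 = keep, 0 = masked out).

    The last query row always sees all keys; shorter query sequences are
    aligned to the bottom-right corner of the attention matrix.
    """
    offset = seqlen_k - seqlen_q
    return [
        [1 if j <= i + offset else 0 for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]
```

For seqlen_q = 2, seqlen_k = 5 this yields the two rows `1 1 1 1 0` / `1 1 1 1 1` from the docstring, and for seqlen_q = 5, seqlen_k = 2 the first three rows are all zero, matching the note that an all-zero row produces a zero output.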
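The sliding-window rule quoted in the docstring (query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive, with -1 meaning unbounded on that side) can be sketched as a small helper, again purely for illustration:

```python
def local_attention_window(i, seqlen_q, seqlen_k, window_size):
    """Inclusive (lo, hi) range of key indices visible to query position i.

    window_size = (left, right); -1 on either side means unbounded,
    matching the (-1, -1) = "infinite context window" default.
    """
    left, right = window_size
    center = i + seqlen_k - seqlen_q  # bottom-right-aligned diagonal
    lo = 0 if left == -1 else max(0, center - left)
    hi = seqlen_k - 1 if right == -1 else min(seqlen_k - 1, center + right)
    return lo, hi
```

With window_size = (2, 0) and seqlen_q = seqlen_k = 5, query 4 sees keys 2..4; with the default (-1, -1) it sees the full 0..4 range.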
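The MQA/GQA head mapping in the docstring (Q with 6 heads and KV with 2 heads: Q heads 0, 1, 2 attend to KV head 0, and Q heads 3, 4, 5 to KV head 1) comes down to integer division by the group size. A sketch for illustration:

```python
def gqa_kv_head(q_head, nheads_q, nheads_k):
    """Map a query head index to the KV head it attends to under MQA/GQA."""
    # The docstring requires nheads_q to be divisible by nheads_k.
    assert nheads_q % nheads_k == 0
    group_size = nheads_q // nheads_k  # query heads sharing one KV head
    return q_head // group_size
```

With nheads_q = 6 and nheads_k = 2 this reproduces the docstring's mapping exactly; with nheads_k = 1 it degenerates to MQA (all query heads share one KV head).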