Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1df7ce2
support download sgl-flash-attn from kernels community
rainj-me Mar 17, 2026
ac55c34
fix the function call error
rainj-me Mar 17, 2026
eae7f75
move all sgl-kernel call to kernels community
rainj-me Mar 17, 2026
d0537b5
add fallback to sgl-kernel solution since the arm is not supported
rainj-me Mar 17, 2026
8ebdc72
remove the kernel lock file from repo and move the docker and ci kerne…
rainj-me Mar 18, 2026
3a9fcdc
Update flash_attention_v3.py
rainj-me Mar 20, 2026
a3f572a
Merge branch 'main' into kernels-community-fa3
rainj-me Mar 20, 2026
8e8eeff
Merge branch 'main' into kernels-community-fa3
rainj-me Mar 23, 2026
fe7e49e
Merge branch 'main' into kernels-community-fa3
rainj-me Mar 25, 2026
4f5c488
Merge branch 'main' into kernels-community-fa3
rainj-me Mar 25, 2026
729469d
Merge branch 'main' into kernels-community-fa3
rainj-me Mar 26, 2026
0498ae9
refactor the fa3 and fa4
rainj-me Mar 31, 2026
c19fdad
fix the unit test for fa3
rainj-me Mar 31, 2026
cac1f74
Merge branch 'main' into kernels-community-fa3
rainj-me Mar 31, 2026
7553c4d
address comments
rainj-me Apr 1, 2026
0cb1af5
fix ci failure with NoneType has no view
rainj-me Apr 1, 2026
e8f4d32
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 1, 2026
86ae3ab
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 2, 2026
e28a8e1
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 2, 2026
b0b52cc
make the fa3 kernels load from sgl-kernel by default
rainj-me Apr 3, 2026
241f205
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 3, 2026
88d0075
fix unsuccessfully resolve
rainj-me Apr 3, 2026
4796bfc
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 3, 2026
6fb65b8
refactor hardcode SGLANG_CACHE_DIR env variable
rainj-me Apr 3, 2026
f622af8
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 3, 2026
65b80ca
fix unsuccessfully resolve
rainj-me Apr 3, 2026
654bd22
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 3, 2026
a9da0ef
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 3, 2026
510efea
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 3, 2026
4eabccd
Merge branch 'main' into kernels-community-fa3
rainj-me Apr 6, 2026
8abedb7
move the kernels imports to try cache block
rainj-me Apr 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ inputs/

# setuptools-scm generated version file
python/sglang/_version.py
python/kernels.lock

# MUSA section
# Generated source files by torchada
Expand Down
7 changes: 7 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install flashinfer-jit-cache==${FLASHINFER_VERSION} --index-url https://flashinfer.ai/whl/cu${CUINDEX} ; \
fi \
    && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin \
    && kernels download python \
    && kernels lock python \
    && mv python/kernels.lock /root/.cache/sglang

# DeepEP
# We use Tom's DeepEP fork for GB200 for now; the 1fd57b0276311d035d16176bb0076426166e52f3 commit is https://github.com/fzyzcjy/DeepEP/tree/gb200_blog_part_2
Expand Down Expand Up @@ -561,6 +564,10 @@ COPY --from=framework /usr/local/lib/python3.12/dist-packages /usr/local/lib/pyt
# Copy SGLang workspace
COPY --from=framework /sgl-workspace /sgl-workspace

# Copy cache for kernels from kernels community
COPY --from=framework /root/.cache/huggingface /root/.cache/huggingface
COPY --from=framework /root/.cache/sglang /root/.cache/sglang

# Fix Triton to use system ptxas for Blackwell (sm_103a) support (CUDA 13+ only)
RUN if [ "${CUDA_VERSION%%.*}" = "13" ] && [ -d /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin ]; then \
rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \
Expand Down
2 changes: 2 additions & 0 deletions docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) |
| `SGLANG_REQ_WAITING_TIMEOUT` | Timeout (in seconds) for requests waiting in the queue before being scheduled | `-1` |
| `SGLANG_REQ_RUNNING_TIMEOUT` | Timeout (in seconds) for requests running in the decode batch | `-1` |
| `SGLANG_CACHE_DIR` | Cache directory for model weights and other data | `~/.cache/sglang` |

## Performance Tuning

Expand Down Expand Up @@ -47,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The algorithm of custom all-reduce. Set to `oneshot` or `1stage` to force use one-shot. Set to `twoshot` or `2stage` to force use two-shot. | `` |
| `SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM prefill attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` |
| `SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM decode attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` |
| `SGLANG_USE_SGL_FA3_KERNEL` | Use sgl-kernel implementation for FlashAttention v3 | `true` |


## DeepGEMM Configuration (Advanced Optimization)
Expand Down
4 changes: 4 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ dependencies = [
"watchfiles",
"xgrammar==0.1.32",
"smg-grpc-servicer>=0.5.0",
"kernels",
]

[[tool.uv.index]]
Expand Down Expand Up @@ -201,3 +202,6 @@ version_file = "sglang/_version.py"
git_describe_command = ["python3", "python/tools/get_version_tag.py", "--tag-only"]
# Allow editable installs even when .git metadata is not available.
fallback_version = "0.0.0.dev0"

[tool.kernels.dependencies]
"kernels-community/sgl-flash-attn3" = 1
286 changes: 286 additions & 0 deletions python/sglang/jit_kernel/flash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
from typing import Optional, Union

import torch

from .flash_attention_v3 import flash_attn_varlen_func as fa3_flash_attn_varlen_func
from .flash_attention_v3 import flash_attn_with_kvcache as fa3_flash_attn_with_kvcache
from .flash_attention_v4 import flash_attn_varlen_func as fa4_flash_attn_varlen_func
from .flash_attention_v4 import flash_attn_with_kvcache as fa4_flash_attn_with_kvcache


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):
    """Version-dispatching front end for FlashAttention decode with a KV cache.

    Forwards every argument unchanged to either the FlashAttention v3 or v4
    ``flash_attn_with_kvcache`` implementation, selected by ``ver``.

    If ``k``/``v`` are given, ``k_cache``/``v_cache`` are updated *in place*
    with the new values (starting at the positions given by ``cache_seqlens``)
    before attention is computed — cache append, optional rotary embedding,
    and attention all happen inside one kernel. The cache must be
    pre-allocated large enough to hold the appended values.

    Rotary embedding: when ``rotary_cos``/``rotary_sin`` are given, ``k`` is
    rotated at indices ``cache_seqlens, cache_seqlens + 1, ...``; ``q`` is
    rotated the same way when causal/local, otherwise at ``cache_seqlens``
    only.

    MQA/GQA is supported by passing K/V with fewer heads than Q (the Q head
    count must be divisible by the KV head count). ``causal=True`` aligns the
    causal mask to the bottom-right corner of the attention matrix.
    ``window_size != (-1, -1)`` enables sliding-window local attention.

    Does not support a backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) without a
            page_table, or (num_blocks, page_block_size, nheads_k, headdim)
            with one (paged KV cache); page_block_size must be a multiple of 256.
        v_cache: as k_cache but with headdim_v as the last dimension.
        k, v [optional]: (batch_size, seqlen_new, nheads_k, headdim[_v]) new
            values appended to the cache at ``cache_seqlens``.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos, rotary_sin [optional]: (seqlen_ro, rotary_dim / 2); only
            applicable if k and v are passed; rotary_dim must be divisible by 16.
        cache_seqlens: int or (batch_size,) int32 — current KV cache lengths.
        cache_batch_idx [optional]: (batch_size,) int32 indices into the cache;
            defaults to [0, 1, ..., batch_size - 1]. With duplicate indices and
            k/v given, updated cache values may come from any duplicate.
        cache_leftpad [optional]: (batch_size,) int32 — index where each KV
            cache entry starts; defaults to 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq) int32.
        softmax_scale: QK^T scaling before softmax; default 1 / sqrt(headdim).
        causal: apply a causal attention mask.
        window_size: (left, right) sliding window; (-1, -1) disables it.
        attention_chunk [optional]: split queries into chunks to save memory.
        softcap: > 0 enables softcapping attention.
        rotary_interleaved: True combines dims 0 & 1, 2 & 3, ...; False uses
            GPT-NeoX style (0 & rotary_dim/2, 1 & rotary_dim/2 + 1, ...).
        num_splits: > 1 splits K/V into that many sequence chunks; 1 disables
            splitting; 0 picks a heuristic. Don't change unless you know why.
        return_softmax_lse: also return the logsumexp of the attention scores.
        score_mod [optional]: callable that modifies the attention scores
            (v4 only).
        aux_tensors [optional]: global tensors threaded through to score_mod
            (v4 only).
        ver: 3 or 4 — which FlashAttention implementation to call.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [if return_softmax_lse]: (batch_size, nheads, seqlen) —
            log of the softmax normalization factor per row of QK^T * scaling.
    """
    # Keyword arguments accepted by both backend signatures.
    shared_kwargs = dict(
        k=k,
        v=v,
        qv=qv,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        page_table=page_table,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        rotary_seqlens=rotary_seqlens,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        return_softmax_lse=return_softmax_lse,
        sinks=sinks,
    )

    if ver == 3:
        # v3-only knobs: appended-K cu_seqlens, chunked attention, rotary
        # layout, precomputed scheduler metadata, and SM margin.
        return fa3_flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            cu_seqlens_k_new=cu_seqlens_k_new,
            attention_chunk=attention_chunk,
            rotary_interleaved=rotary_interleaved,
            scheduler_metadata=scheduler_metadata,
            sm_margin=sm_margin,
            **shared_kwargs,
        )
    if ver == 4:
        # v4-only knobs: user-supplied score modification and its aux tensors.
        return fa4_flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            **shared_kwargs,
        )
    raise RuntimeError(f"Unknown flash attention version {ver}")


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    score_mod=None,
    aux_tensors=None,
    ver=3,
):
    """Version-dispatching front end for variable-length FlashAttention.

    Forwards every argument unchanged to either the FlashAttention v3 or v4
    ``flash_attn_varlen_func`` implementation, selected by ``ver``. Sequences
    of different lengths are packed into ``q``/``k``/``v`` and delimited by
    the cumulative-length tensors ``cu_seqlens_q``/``cu_seqlens_k``.

    The v3 backend additionally receives ``qv``, the Q/K/V descale tensors,
    ``attention_chunk`` and ``sm_margin``; the v4 backend additionally
    receives ``score_mod`` and ``aux_tensors``.

    Raises:
        RuntimeError: if ``ver`` is neither 3 nor 4.
    """
    # Keyword arguments accepted by both backend signatures.
    shared_kwargs = dict(
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
        page_table=page_table,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        return_softmax_lse=return_softmax_lse,
        sinks=sinks,
    )

    if ver == 3:
        return fa3_flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            qv=qv,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            attention_chunk=attention_chunk,
            sm_margin=sm_margin,
            **shared_kwargs,
        )
    if ver == 4:
        return fa4_flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            **shared_kwargs,
        )
    raise RuntimeError(f"Unknown flash attention version {ver}")
Loading
Loading