diff --git a/.gitignore b/.gitignore
index b5917c299ecf..a8aa903e28f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -258,6 +258,7 @@ inputs/
 
 # setuptools-scm generated version file
 python/sglang/_version.py
+python/kernels.lock
 
 # MUSA section
 # Generated source files by torchada
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d7f4ead4579c..57842c53564b 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -219,6 +219,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
         python3 -m pip install flashinfer-jit-cache==${FLASHINFER_VERSION} --index-url https://flashinfer.ai/whl/cu${CUINDEX} ; \
     fi \
-    && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin
+    && FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin \
+    && kernels download python \
+    && kernels lock python \
+    && mkdir -p /root/.cache/sglang \
+    && mv python/kernels.lock /root/.cache/sglang
 
 # DeepEP
 # We use Tom's DeepEP fork for GB200 for now; the 1fd57b0276311d035d16176bb0076426166e52f3 commit is https://github.com/fzyzcjy/DeepEP/tree/gb200_blog_part_2
@@ -561,6 +566,10 @@ COPY --from=framework /usr/local/lib/python3.12/dist-packages /usr/local/lib/pyt
 
 # Copy SGLang workspace
 COPY --from=framework /sgl-workspace /sgl-workspace
 
+# Copy the kernel caches for kernels from the kernels community
+COPY --from=framework /root/.cache/huggingface /root/.cache/huggingface
+COPY --from=framework /root/.cache/sglang /root/.cache/sglang
+
 # Fix Triton to use system ptxas for Blackwell (sm_103a) support (CUDA 13+ only)
 RUN if [ "${CUDA_VERSION%%.*}" = "13" ] && [ -d /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin ]; then \
     rm -f /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas && \
diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md
index e2e93b177b9c..b7ac94a71245 100644
--- a/docs/references/environment_variables.md
+++ b/docs/references/environment_variables.md
@@ -19,6 +19,7 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) |
 | `SGLANG_REQ_WAITING_TIMEOUT` | Timeout (in seconds) for requests waiting in the queue before being scheduled | `-1` |
 | `SGLANG_REQ_RUNNING_TIMEOUT` | Timeout (in seconds) for requests running in the decode batch | `-1` |
+| `SGLANG_CACHE_DIR` | Cache directory for model weights and other data | `~/.cache/sglang` |
 
 ## Performance Tuning
 
@@ -47,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_CUSTOM_ALLREDUCE_ALGO` | The algorithm of custom all-reduce. Set to `oneshot` or `1stage` to force use one-shot. Set to `twoshot` or `2stage` to force use two-shot. | `` |
 | `SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM prefill attention in flashinfer. `None` means standard attention. See https://arxiv.org/abs/2512.12087 | `None` |
 | `SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR` | Skip-softmax threshold scale factor for TRT-LLM decode attention in flashinfer. `None` means standard attention. 
See https://arxiv.org/abs/2512.12087 | `None` | +| `SGLANG_USE_SGL_FA3_KERNEL` | Use sgl-kernel implementation for FlashAttention v3 | `true` | ## DeepGEMM Configuration (Advanced Optimization) diff --git a/python/pyproject.toml b/python/pyproject.toml index 8e96b44afe3c..8faaa55f8fbd 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -77,6 +77,7 @@ dependencies = [ "watchfiles", "xgrammar==0.1.32", "smg-grpc-servicer>=0.5.0", + "kernels", ] [[tool.uv.index]] @@ -201,3 +202,6 @@ version_file = "sglang/_version.py" git_describe_command = ["python3", "python/tools/get_version_tag.py", "--tag-only"] # Allow editable installs even when .git metadata is not available. fallback_version = "0.0.0.dev0" + +[tool.kernels.dependencies] +"kernels-community/sgl-flash-attn3" = 1 diff --git a/python/sglang/jit_kernel/flash_attention.py b/python/sglang/jit_kernel/flash_attention.py new file mode 100644 index 000000000000..633863d0a648 --- /dev/null +++ b/python/sglang/jit_kernel/flash_attention.py @@ -0,0 +1,286 @@ +from typing import Optional, Union + +import torch + +from .flash_attention_v3 import flash_attn_varlen_func as fa3_flash_attn_varlen_func +from .flash_attention_v3 import flash_attn_with_kvcache as fa3_flash_attn_with_kvcache +from .flash_attention_v4 import flash_attn_varlen_func as fa4_flash_attn_varlen_func +from .flash_attention_v4 import flash_attn_with_kvcache as fa4_flash_attn_with_kvcache + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + sinks=None, + score_mod=None, + aux_tensors=None, + ver=3, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. 
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
+    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Note: Does not support backward pass.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
+        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+            k with k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
+        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
+        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
+        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory.
+        softcap: float. Anything > 0 activates softcapping attention.
+        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. 
If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + score_mod [optional]: A callable that takes the attention scores and applies a modification. + aux_tensors [optional]: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + + if ver == 3: + return fa3_flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + return_softmax_lse=return_softmax_lse, + sinks=sinks, + ) + elif ver == 4: + return fa4_flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=k, + v=v, + qv=qv, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sinks=sinks, + score_mod=score_mod, + aux_tensors=aux_tensors, + return_softmax_lse=return_softmax_lse, + ) + else: + raise RuntimeError(f"Unknown flash attention version {ver}") + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=None, + max_seqlen_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(-1, -1), + attention_chunk=0, + softcap=0.0, + num_splits=1, + pack_gqa=None, + sm_margin=0, + return_softmax_lse=False, + sinks=None, + score_mod=None, + aux_tensors=None, + ver=3, +): + + if ver == 3: + return fa3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + sm_margin=sm_margin, + 
return_softmax_lse=return_softmax_lse,
+            sinks=sinks,
+        )
+    elif ver == 4:
+        return fa4_flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
+            page_table=page_table,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            softcap=softcap,
+            window_size=window_size,
+            sinks=sinks,
+            num_splits=num_splits,
+            pack_gqa=pack_gqa,
+            score_mod=score_mod,
+            aux_tensors=aux_tensors,
+            return_softmax_lse=return_softmax_lse,
+        )
+    else:
+        raise RuntimeError(f"Unknown flash attention version {ver}")
diff --git a/python/sglang/jit_kernel/flash_attention_v3.py b/python/sglang/jit_kernel/flash_attention_v3.py
new file mode 100644
index 000000000000..23018961d998
--- /dev/null
+++ b/python/sglang/jit_kernel/flash_attention_v3.py
@@ -0,0 +1,240 @@
+import logging
+import os
+from typing import Optional, Union
+
+import torch
+
+from sglang.jit_kernel.utils import cache_once
+from sglang.kernel_api_logging import debug_kernel_api
+from sglang.srt.environ import envs
+
+logger = logging.getLogger(__name__)
+
+SGL_FA3_KERNEL_REPO = "kernels-community/sgl-flash-attn3"
+SGL_FA3_KERNEL_REVISION = "v1"
+DEFAULT_FA3_KERNEL_LOCKFILE = "kernels.lock"
+
+
+@cache_once
+def _load_fa3_kernels():
+    # By default, we use the implementation from sgl-kernel,
+    # which is expected to be more stable and compatible.
+    if envs.SGLANG_USE_SGL_FA3_KERNEL.get():
+        logger.debug(
+            "SGLANG_USE_SGL_FA3_KERNEL=true; using the sgl-kernel implementation for FlashAttention v3"
+        )
+        return _load_fa3_kernel_from_sgl()
+
+    # Otherwise, try to load the kernels from the local cache directory or from the kernels-community repo.
+    lockfile_path = os.path.join(
+        envs.SGLANG_CACHE_DIR.get(), DEFAULT_FA3_KERNEL_LOCKFILE
+    )
+
+    try:
+        from kernels import get_kernel, load_kernel
+
+        # When the lockfile is present, load from the local kernel cache directory;
+        # otherwise load from the repo, which requires a download from the Hugging Face
+        # Hub but always works as long as the repo is accessible.
+        if os.path.exists(lockfile_path):
+            ops = load_kernel(SGL_FA3_KERNEL_REPO, lockfile_path)
+        else:
+            ops = get_kernel(SGL_FA3_KERNEL_REPO, revision=SGL_FA3_KERNEL_REVISION)
+
+        return {
+            "flash_attn_with_kvcache": ops.flash_attn_with_kvcache,
+            "flash_attn_varlen_func": ops.flash_attn_varlen_func,
+        }
+    except Exception as e:
+        # If the kernels cannot be loaded from the repo or the cache directory,
+        # catch the exception, log a warning, and fall back to the implementation
+        # from sgl-kernel, which is expected to be less efficient but more compatible.
+        logger.warning(
+            f"Falling back to the sgl-kernel implementation since loading FlashAttention v3 "
+            f"kernels from {SGL_FA3_KERNEL_REPO} with lockfile {lockfile_path} failed: {e}"
+        )
+        return _load_fa3_kernel_from_sgl()
+
+
+def _load_fa3_kernel_from_sgl():
+    from sgl_kernel.flash_attn import (
+        flash_attn_varlen_func,
+        flash_attn_with_kvcache,
+    )
+
+    return {
+        "flash_attn_with_kvcache": flash_attn_with_kvcache,
+        "flash_attn_varlen_func": flash_attn_varlen_func,
+    }
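+
+
+# Note: the lockfile consumed above is produced at image build time (see
+# docker/Dockerfile in this change); the flow is roughly:
+#     kernels download python      # fetch kernels declared in pyproject.toml
+#     kernels lock python          # write python/kernels.lock
+#     mv python/kernels.lock /root/.cache/sglang
+# where /root/.cache/sglang is the default SGLANG_CACHE_DIR in the image.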
+
+
+@cache_once
+def _is_fa3_supported(device=None) -> bool:
+    # A few FA3 caveats:
+    # FA3 can fail when there is not enough shared memory for some shapes, such as
+    # a large hidden_dim or some other special cases.
+    # Right now, FA3 is supported on sm80/sm87 and sm86/sm89; the main difference
+    # between them is the shared memory size. See the link below for more information:
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    # And for sgl-kernel right now, we can build FA3 on sm80/sm86/sm89/sm90a.
+    # That means you can use FA3 on A100/A*0/L20/L40/L40s/4090.
+    return (torch.version.cuda >= "12.3") and (
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    )
+
+
+@debug_kernel_api
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    attention_chunk: Optional[int] = None,
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+    sinks=None,
+):
+    if not _is_fa3_supported():
+        raise NotImplementedError(
+            "flash_attn at sgl-kernel is only supported on sm80 and sm90"
+        )
+
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+
+    return _load_fa3_kernels()["flash_attn_with_kvcache"](
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        qv,
+        rotary_cos,
+        rotary_sin,
+        cache_seqlens,
+        cache_batch_idx,
+        cache_leftpad,
+        page_table,
+        cu_seqlens_q,
+        cu_seqlens_k_new,
+        max_seqlen_q,
+        rotary_seqlens,
+        q_descale,
+        k_descale,
+        v_descale,
+        softmax_scale,
+        causal,
+        window_size,
+        attention_chunk,
+        softcap,
+        rotary_interleaved,
+        scheduler_metadata,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+        return_softmax_lse,
+        sinks,
+    )
+
+
+@debug_kernel_api
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q=None,
+    max_seqlen_k=None,
+    seqused_q=None,
+    seqused_k=None,
+    page_table=None,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(-1, -1),
+    attention_chunk=0,
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    sm_margin=0,
+    return_softmax_lse=False,
+    sinks=None,
+):
+
+    if not _is_fa3_supported():
+        raise NotImplementedError(
+            "flash_attn at sgl-kernel is only supported on sm80 and sm90"
+        )
+
+    return _load_fa3_kernels()["flash_attn_varlen_func"](
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqused_q,
+        seqused_k,
+        page_table,
+        softmax_scale,
+        causal,
+        qv,
+        q_descale,
+        k_descale,
+        v_descale,
+        window_size,
+        attention_chunk,
+        softcap,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+        return_softmax_lse,
+        sinks,
+    )
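+
+
+# Usage sketch (illustrative only; shapes and semantics are documented in
+# sglang.jit_kernel.flash_attention, which dispatches to this module for ver=3):
+#
+#     out = flash_attn_varlen_func(
+#         q, k, v, cu_seqlens_q, cu_seqlens_k,
+#         max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
+#         causal=True,
+#     )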
diff --git a/python/sglang/jit_kernel/flash_attention_v4.py b/python/sglang/jit_kernel/flash_attention_v4.py
index 0a79614ee075..46b49d177388 100644
--- a/python/sglang/jit_kernel/flash_attention_v4.py
+++ b/python/sglang/jit_kernel/flash_attention_v4.py
@@ -42,7 +42,6 @@ def flash_attn_varlen_func(
     score_mod: Optional[Callable] = None,
     aux_tensors: Optional[list] = None,
     return_softmax_lse: bool = False,
-    **_: object,
 ):
     if _flash_attn_varlen_func is None:  # pragma: no cover
         raise ImportError(
diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_3.py b/python/sglang/jit_kernel/tests/test_flash_attention_3.py
new file mode 100644
index 000000000000..e4687da9c827
--- /dev/null
+++ b/python/sglang/jit_kernel/tests/test_flash_attention_3.py
@@ -0,0 +1,1378 @@
+# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
+import itertools
+import math
+from typing import Optional
+
+import pytest
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+apply_rotary_emb = None
+
+from sglang.test.ci.ci_register import register_cuda_ci
+
+register_cuda_ci(est_time=120, suite="stage-b-kernel-unit-1-gpu-large")
+register_cuda_ci(est_time=900, suite="nightly-kernel-1-gpu", nightly=True)
+
+
+def is_hopper():
+    # Only Hopper supports different V headdim
+    return torch.cuda.get_device_properties(0).major == 9
+
+
+def is_fa3_supported(device=None) -> bool:
+    # A few FA3 caveats:
+    # FA3 can fail when there is not enough shared memory for some shapes, such as
+    # a large hidden_dim or some other special cases.
+    # Right now, FA3 is supported on sm80/sm87 and sm86/sm89; the main difference
+    # between them is the shared memory size. See the link below for more information:
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    # And for sgl-kernel right now, we can build FA3 on sm80/sm86/sm89/sm90a.
+    # That means you can use FA3 on A100/A*0/L20/L40/L40s/4090.
+    return (torch.version.cuda >= "12.3") and (
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    )
+
+
+DISABLE_BACKWARD = True
+# For CI, these switches are pinned to fixed values instead of being read from env vars.
+# DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
+# DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
+# DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
+# DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
+# DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
+# DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
+# DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
+# DISABLE_FP8 = (
+#     os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
+#     or torch.cuda.get_device_capability("cuda")[0] < 9
+# )
+
+DISABLE_SPLIT = False
+DISABLE_PAGEDKV = True
+DISABLE_APPENDKV = False
+DISABLE_LOCAL = False
+DISABLE_SOFTCAP = True
+DISABLE_PACKGQA = False
+DISABLE_FP16 = True
+DISABLE_FP8 = True
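+
+
+# A quick cu_seqlens primer for the helpers below (illustrative): for per-sequence
+# lengths [2, 3, 1], unpad_input packs tokens into a (6, ...) tensor and returns
+# cu_seqlens = [0, 2, 5, 6], so sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1].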
+
+
+# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of valid tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (
+        (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    )
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
+    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices.
+    return (
+        rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+def generate_random_padding_mask(
+    max_seqlen, batch_size, device, mode="random", zero_lengths=False
+):
+    assert mode in ["full", "random", "third"]
+    if mode == "full":
+        lengths = torch.full(
+            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32
+        )
+    elif mode == "random":
+        lengths = torch.randint(
+            max(0 if zero_lengths else 1, max_seqlen - 20),
+            max_seqlen + 1,
+            (batch_size, 1),
+            device=device,
+        )
+    elif mode == "third":
+        lengths = torch.randint(
+            max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
+        )
+
+    if zero_lengths:
+        # Generate zero-lengths every 5 batches and the last batch.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
+        < lengths
+    )
+    return padding_mask
+
+
+def pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[1:]
+    output = torch.zeros(
+        (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype
+    )
+    output[indices] = hidden_states
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    sink_token_length=0,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    device=None,
+):
+    row_idx = rearrange(
+        torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
+    )
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            torch.logical_and(
+                col_idx < row_idx + sk - sq - window_size[0],
+                col_idx >= sink_token_length,
+            ),
+        )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    key_leftpad=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(-1, -1),  # -1 means infinite window size
+    sink_token_length=0,
+    sinks: Optional[torch.Tensor] = None,
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    intermediate_dtype=None,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads, head_dim)
+        v: (batch_size, seqlen_k, nheads, head_dim_v)
+        qv: (batch_size, seqlen_q, nheads, head_dim_v)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering. 
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim_v) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if sinks is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + sinks = rearrange(sinks, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(sinks, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( + sinks - logits_or_sinks_max + ) + attention = (unnormalized_scores / normalizer).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) + # Without this we might get NaN in dv + if key_padding_mask is not None: + attention = attention.masked_fill( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0 + ) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill( + torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + 
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + kvpacked=False, + qkvpacked=False, + add_unused_qkv=False, + query_unused_mask=None, + key_unused_mask=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, + query_padding_mask, + query_unused_mask, + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input( + dqkv_unpad, indices_q, batch_size, seqlen_q + ) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input( + dkv_unpad, indices_k, batch_size, seqlen_k + ) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input( + dk_unpad, indices_k, batch_size, seqlen_k + ) + else: + dk_pad_fn = lambda dk_unpad: rearrange( + dk_unpad, "(b s) h d -> b s h d", b=batch_size + ) + 
return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +@pytest.mark.skipif( + not is_fa3_supported(), + reason="flash_attn at sgl-kernel is only supported on sm90 or sm80", +) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_sink", [False, True]) +# @pytest.mark.parametrize("has_sink", [False]) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize( +# "causal,local", +# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []), +# ) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +@pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize( + "seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True] +) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize( + "rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False] +) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize( + "rotary_fraction", + ( + [0.0, 0.5, 1.0] + if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) + else [0.0] + ), +) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize( + "page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []) +) +# @pytest.mark.parametrize("page_size", [None]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, + has_sink, +): + from 
sgl_kernel.flash_attn import flash_attn_with_kvcache + + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + + if has_sink: + sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + sinks = None + + if dtype == torch.float8_e4m3fn or not is_hopper(): + # for fp8 and ampere arch, we not support v head dim != qk head dim + dv_vals = [d] + for dv in dv_vals: + has_qv = d == 64 and dv >= 256 + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... 
-> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + ( + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + ) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or 
local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + key_leftpad=cache_leftpad, + sinks=sinks, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + sinks=sinks, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else 
qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True, + sinks=sinks, + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.skipif( + not is_fa3_supported(), + reason="flash_attn at sgl-kernel is only supported on sm90 or sm80", +) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_sink", [False, True]) +# @pytest.mark.parametrize("has_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + mha_type, + dtype, + has_sink, +): + from sglang.jit_kernel.flash_attention import flash_attn_varlen_func + + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + if has_sink: + sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + sinks = None + + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + sinks=sinks, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + sinks=sinks, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on 
out_ref
+    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
+    rtol = 2 if softcap == 0.0 else 3
+
+    pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
+    num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
+    for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
+        out_unpad, lse, *rest = flash_attn_varlen_func(
+            q_unpad,
+            k_unpad,
+            v_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
+            causal=causal,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            window_size=window_size,
+            softcap=softcap,
+            # Forward the swept values; without these, every iteration of the
+            # itertools.product loop above runs the same default config.
+            pack_gqa=pack_gqa,
+            num_splits=num_splits,
+            return_softmax_lse=True,
+            sinks=sinks,
+        )
+        out = output_pad_fn(out_unpad)
+        if query_unused_mask is not None:
+            out.masked_fill_(q_zero_masking, 0.0)
+        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+
+        # Check that FlashAttention's numerical error is at most rtol times the
+        # numerical error of a Pytorch implementation (rtol is 2, or 3 with softcap).
+        assert (out - out_ref).abs().max().item() <= rtol * (
+            out_pt - out_ref
+        ).abs().max().item() + fwd_atol
+
+    if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
+        g_unpad = torch.randn_like(out_unpad)
+        do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
+        dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
+            out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
+        )
+        dq = dq_pad_fn(dq_unpad)
+        dk = dk_pad_fn(dk_unpad)
+        # dk_pad_fn also pads dv: both are indexed by the key sequence positions.
+        dv = dk_pad_fn(dv_unpad)
+        if key_unused_mask is not None:
+            k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
+            dk.masked_fill_(k_zero_masking, 0.0)
+            dv.masked_fill_(k_zero_masking, 0.0)
+        if query_unused_mask is not None:
+            dq.masked_fill_(q_zero_masking, 0.0)
+        # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
+        # assert (softmax_d - do_o).abs().max().item() <= 1e-5
+        # assert dq_accum.abs().max().item() == 0.0
+        g = output_pad_fn(g_unpad)
+
+        # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
+        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
+        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
+        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+    if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
+        # Each atol is a noise floor measured by a round-trip add/subtract on
+        # the reference tensor; softcap adds a small extra allowance.
+        dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
+            0 if softcap == 0 else 3e-4
+        )
+        assert (dq - dq_ref).abs().max().item() <= rtol * (
+            dq_pt - dq_ref
+        ).abs().max().item() + dq_atol
+        dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
+            0 if softcap == 0 else 3e-4
+        )
+        assert (dk - dk_ref).abs().max().item() <= rtol * (
+            dk_pt - dk_ref
+        ).abs().max().item() + dk_atol
+        dv_atol = 2 * (dv_ref + 0.3 - 0.3 -
dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_4.py b/python/sglang/jit_kernel/tests/test_flash_attention_4.py index e1453b8f2323..81b0f0b23d62 100644 --- a/python/sglang/jit_kernel/tests/test_flash_attention_4.py +++ b/python/sglang/jit_kernel/tests/test_flash_attention_4.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from einops import rearrange, repeat -from sglang.jit_kernel.flash_attention_v4 import flash_attn_varlen_func +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci(est_time=120, suite="stage-b-kernel-unit-1-gpu-large") @@ -826,6 +826,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): sinks=learnable_sink, # FA4 uses learnable_sink, not sinks pack_gqa=pack_gqa, return_softmax_lse=True, + ver=4, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -1384,6 +1385,7 @@ def test_flash_attn_kvcache( softcap=0.0, pack_gqa=None, return_softmax_lse=True, + ver=4, ) if varlen_q: out = output_pad_fn(out) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index 9c30a9798283..31372e2e16ce 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -5,27 +5,13 @@ import torch +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context from sglang.multimodal_gen.runtime.platforms import ( AttentionBackendEnum, ) -try: - from sgl_kernel.flash_attn import flash_attn_varlen_func - - from sglang.jit_kernel.flash_attention_v4 import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa4, - ) - - def flash_attn_func(*args, ver: int = 3, **kwargs): - if ver == 4: - return flash_attn_varlen_func_fa4(*args, **kwargs) - return flash_attn_varlen_func(*args, **kwargs) - -except ImportError as e: - raise e - def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -207,7 +193,7 @@ def flash_attn_varlen_func_op( "flash_attn_varlen_func_op is out-only op; return_softmax_lse must be False. " "Use flash_attn_varlen_func_op_lse for (out, lse)." ) - return flash_attn_func( + return flash_attn_varlen_func( q, k, v, @@ -271,7 +257,7 @@ def flash_attn_varlen_func_op_lse( "flash_attn_varlen_func_op_lse is out+lse op; return_softmax_lse must be True. " "Use flash_attn_varlen_func_op for out-only." 
) - return flash_attn_func( + return flash_attn_varlen_func( q, k, v, @@ -409,7 +395,7 @@ def forward( # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag) # - fa_ver == 4: call custom ops with FIXED return schema if fa_ver == 3: - flash_attn_op = flash_attn_func + flash_attn_op = flash_attn_varlen_func output = flash_attn_op( q=query, k=key, diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index f9d376e959be..201123324068 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -21,6 +21,7 @@ from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager +from sglang.srt.environ import envs from sglang.srt.utils.common import is_npu logger = logging.getLogger(__name__) @@ -393,9 +394,7 @@ def configure_post_pass(self): self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: - base_cache_dir = os.path.expanduser( - os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") - ) + base_cache_dir = envs.SGLANG_CACHE_DIR.get() cache_hash = self.compiler_manager.compute_hash() cache_dir = os.path.join( diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index dfc5507de0ba..e1d8a56e6cb4 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -406,6 +406,9 @@ class Envs: # sgl-kernel SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False) + # Flash Attention + SGLANG_USE_SGL_FA3_KERNEL = EnvBool(True) + # vLLM dependencies (TODO: they have been deprecated, we can remove them safely) USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False) @@ -531,6 +534,9 @@ class Envs: # Elastic EP Backup Port SGLANG_BACKUP_PORT_BASE = EnvInt(10000) + # Sglang Cache Dir + SGLANG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/sglang")) + envs = Envs() EnvField._allow_set_name = False diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py index e522fbe4a934..a84015a803f8 100644 --- a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py @@ -9,13 +9,16 @@ import torch import torch.nn.functional as F -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from sgl_kernel.sparse_flash_attn import ( convert_vertical_slash_indexes, convert_vertical_slash_indexes_mergehead, sparse_attn_func, ) +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index ff170c390838..ad7f59c0d539 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -27,6 +27,11 @@ from sgl_kernel import merge_state_v2 +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + 
flash_attn_with_kvcache, +) + @dataclass class FlashAttentionMetadata: @@ -616,9 +621,6 @@ def forward_extend( and not is_swa_layer ) - flash_attn_varlen_func = self.flash_attn_varlen_func - flash_attn_with_kvcache = self.flash_attn_with_kvcache - kwargs = {} if sinks is not None: kwargs["sinks"] = sinks @@ -696,6 +698,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) @@ -723,6 +726,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) @@ -750,6 +754,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) o, _ = merge_state_v2_wrapper( @@ -789,6 +794,7 @@ def _fa_cp_attn( softmax_scale=layer.scaling, causal=False, return_softmax_lse=True, + ver=self.fa_impl_ver, **kwargs, ) else: @@ -814,6 +820,7 @@ def _fa_cp_attn( softmax_scale=layer.scaling, causal=True, return_softmax_lse=forward_batch.mha_return_lse, + ver=self.fa_impl_ver, **kwargs, ) if forward_batch.mha_return_lse: @@ -822,7 +829,7 @@ def _fa_cp_attn( return output, lse return output else: - assert self.fa_impl_ver in [3], "Only FA3 support here" + assert self.fa_impl_ver == 3, "Only FA3 support here" # Do absorbed multi-latent attention kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( layer.layer_id @@ -865,6 +872,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -887,6 +895,7 @@ def _fa_cp_attn( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) ) o, _ = merge_state_v2_wrapper( @@ -964,8 +973,6 @@ def forward_decode( if sinks is not None: kwargs["sinks"] = sinks - flash_attn_with_kvcache = self.flash_attn_with_kvcache - k_descale, v_descale = None, None # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention # has corresponding quantization method so that layer.k_scale is not None, @@ -1009,6 +1016,7 @@ def forward_decode( k_descale=k_descale, v_descale=v_descale, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) elif use_local_attn: @@ -1029,6 +1037,7 @@ def forward_decode( k_descale=k_descale, v_descale=v_descale, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) else: @@ -1066,6 +1075,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=use_cascade_attn, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) if use_cascade_attn: @@ -1088,6 +1098,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, **kwargs, ) ) @@ -1144,6 +1155,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states num_splits=self.num_splits, + ver=self.fa_impl_ver, ) if use_cascade_attn: o, softmax_lse, *rest = result @@ -1165,6 +1177,7 @@ def forward_decode( v_descale=v_descale, return_softmax_lse=True, num_splits=self.num_splits, + ver=self.fa_impl_ver, ) o, _ = merge_state_v2( o, diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 862488e5f918..314c897ab313 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -61,7 +61,10 @@ "aiter is AMD specific kernel library. 
Please make sure aiter is installed on your AMD device." ) else: - from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) # Reuse this workspace buffer across all NSA backend instances diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 3fd45aac0101..ddda8147f969 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -38,21 +38,9 @@ if _is_cuda: from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache - try: - from sgl_kernel.flash_attn import flash_attn_varlen_func - - def flash_attn_func(*args, ver: int = 3, **kwargs): - if ver == 4: - from sglang.jit_kernel.flash_attention_v4 import ( - flash_attn_varlen_func as flash_attn_varlen_func_fa4, - ) - - return flash_attn_varlen_func_fa4(*args, **kwargs) - return flash_attn_varlen_func(*args, **kwargs) - - except ImportError as e: - raise e - + from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + ) if _is_npu: import torch_npu @@ -408,7 +396,7 @@ def forward( """ if envs.SGLANG_VIT_ENABLE_CUDA_GRAPH.get(): max_seqlen = cu_seqlens[1] - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, @@ -423,7 +411,7 @@ def forward( seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, @@ -474,7 +462,7 @@ def forward( seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() - output = flash_attn_func( + output = flash_attn_varlen_func( q, k, v, diff --git a/python/sglang/srt/layers/attention/xpu_backend.py b/python/sglang/srt/layers/attention/xpu_backend.py index 4a40d25ee8c9..77e773d88d0c 100644 --- a/python/sglang/srt/layers/attention/xpu_backend.py +++ b/python/sglang/srt/layers/attention/xpu_backend.py @@ -20,7 +20,11 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sgl_kernel import merge_state_v2 -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +from sglang.jit_kernel.flash_attention import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) class XPUAttentionBackend(AttentionBackend): diff --git a/python/sglang/srt/utils/runai_utils.py b/python/sglang/srt/utils/runai_utils.py index 0424a6371bde..dd74efb6626d 100644 --- a/python/sglang/srt/utils/runai_utils.py +++ b/python/sglang/srt/utils/runai_utils.py @@ -5,6 +5,8 @@ import os from pathlib import Path +from sglang.srt.environ import envs + logger = logging.getLogger(__name__) SUPPORTED_SCHEMES = ["s3://", "gs://", "az://"] @@ -26,12 +28,6 @@ # This avoids file locks, race conditions, and duplicate downloads -def get_cache_dir() -> str: - # Expand user path (~) to ensure absolute paths for locking - path = os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/") - return os.path.expanduser(path) - - def list_safetensors(path: str = "") -> list[str]: """ List full file names from object path and filter by allow pattern. @@ -122,7 +118,7 @@ def get_path(cls, model_path: str) -> str: Returns the local directory path. 
""" model_hash = hashlib.sha256(str(model_path).encode()).hexdigest()[:16] - base_dir = get_cache_dir() + base_dir = envs.SGLANG_CACHE_DIR.get() # Ensure base cache dir exists os.makedirs(os.path.join(base_dir, "model_streamer"), exist_ok=True) diff --git a/scripts/ci/cuda/ci_install_dependency.sh b/scripts/ci/cuda/ci_install_dependency.sh index 5bfbea04ffeb..c10a79e62222 100755 --- a/scripts/ci/cuda/ci_install_dependency.sh +++ b/scripts/ci/cuda/ci_install_dependency.sh @@ -358,6 +358,10 @@ mark_step_done "Fix other dependencies" # can delete the .pth file without reliably recreating it (pip race condition). $PIP_CMD install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --no-deps --force-reinstall $PIP_INSTALL_SUFFIX || true +# Download kernels from kernels community +kernels download python || true +kernels lock python || true +mv python/kernels.lock ${HOME}/.cache/sglang || true # Install human-eval pip install "setuptools==70.0.0" diff --git a/test/srt/cpu/test_flash_attn.py b/test/srt/cpu/test_flash_attn.py index 8b1faa98b5cb..4e1968fa06e7 100644 --- a/test/srt/cpu/test_flash_attn.py +++ b/test/srt/cpu/test_flash_attn.py @@ -1,15 +1,12 @@ import unittest -import sgl_kernel # noqa: F401 import torch import torch.nn.functional as F from utils import parametrize, precision +from sglang.jit_kernel.flash_attention import flash_attn_varlen_func from sglang.test.test_utils import CustomTestCase -flash_attn_varlen_func = torch.ops.sgl_kernel.flash_attn_varlen_func - - torch.manual_seed(1234)