diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 6fefea9d43d..1ca46705466 100755
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -121,12 +121,12 @@ set(SGL_KERNEL_CUDA_FLAGS
     # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
 )
 
-option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
-option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
-option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
-option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
-option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
-option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
+option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
+option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
+option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
+option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
+option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
+option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
 
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
     list(APPEND SGL_KERNEL_CUDA_FLAGS
@@ -233,7 +233,7 @@ install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
 
 # ============================ Optional Install ============================= #
 # set flash-attention sources file
-# BF16 source files
+# Now FA3 supports sm80/sm86/sm90
 if (SGL_KERNEL_ENABLE_FA3)
     set(SGL_FLASH_KERNEL_CUDA_FLAGS
         "-DNDEBUG"
@@ -241,6 +241,8 @@ if (SGL_KERNEL_ENABLE_FA3)
         "-O3"
         "-Xcompiler"
         "-fPIC"
+        "-gencode=arch=compute_80,code=sm_80"
+        "-gencode=arch=compute_86,code=sm_86"
         "-gencode=arch=compute_90a,code=sm_90a"
         "-std=c++17"
         "-DCUTE_USE_PACKED_TUPLE=1"
@@ -256,6 +258,10 @@ if (SGL_KERNEL_ENABLE_FA3)
         "-Xcompiler=-fno-strict-aliasing"
     )
 
+    # SM8x source files
+    file(GLOB FA3_SM8X_GEN_SRCS
+        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
+
     file(GLOB FA3_BF16_GEN_SRCS
         "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
     file(GLOB FA3_BF16_GEN_SRCS_
@@ -276,7 +282,7 @@ if (SGL_KERNEL_ENABLE_FA3)
         "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
     list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
 
-    set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
+    set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})
 
     set(FLASH_SOURCES
         "csrc/flash_extension.cc"
@@ -297,7 +303,7 @@ if (SGL_KERNEL_ENABLE_FA3)
     install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
 
     target_compile_definitions(flash_ops PRIVATE
-        FLASHATTENTION_DISABLE_SM8x
+        # FLASHATTENTION_DISABLE_SM8x
         FLASHATTENTION_DISABLE_BACKWARD
         FLASHATTENTION_DISABLE_DROPOUT
         FLASHATTENTION_DISABLE_UNEVEN_K
diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md
index 423394f36d7..919f4358074 100644
--- a/sgl-kernel/README.md
+++ b/sgl-kernel/README.md
@@ -41,6 +41,14 @@ Third-party libraries:
 - [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM)
 - [FlashAttention](https://github.com/Dao-AILab/flash-attention)
 
+### FlashAttention FYI
+
+FA3 can fail when there is not enough shared memory for some shapes, such as a large hidden_dim or other special cases. Right now, FA3 supports sm80/sm87 as well as sm86/sm89.
+
+The main difference between sm80/sm87 and sm86/sm89 is the shared memory size; see https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x for more information.
+
+For sgl-kernel right now, FA3 can be built for sm80/sm86/sm89/sm90a. That means FA3 is usable on **A100 (tested)**/A*0/**L20 (tested)**/L40/L40s/**3090 (tested)**.
+
 ### Kernel Development
 
 Steps to add a new kernel:
diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index 2d1a79489ca..e849d0df002 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -10,10 +10,18 @@
 
 
 def is_fa3_supported(device=None) -> bool:
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.3
-    return (torch.cuda.get_device_capability(device)[0] == 9) and (
-        torch.version.cuda >= "12.3"
-    )
+    # FA3 FYI:
+    # FA3 can fail when there is not enough shared memory for some shapes, such as
+    # a large hidden_dim or other special cases.
+    # Right now, FA3 supports sm80/sm87 as well as sm86/sm89. The main difference
+    # between sm80/sm87 and sm86/sm89 is the shared memory size; see the link below:
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    # For sgl-kernel right now, FA3 can be built for sm80/sm86/sm89/sm90a.
+    # That means FA3 is usable on A100/A*0/L20/L40/L40s/4090, for example.
+    return (
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    ) and (torch.version.cuda >= "12.3")
 
 
 def maybe_contiguous(x):
diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py
index ae061b1bf62..2885dbb4bd4 100644
--- a/sgl-kernel/tests/test_flash_attention.py
+++ b/sgl-kernel/tests/test_flash_attention.py
@@ -11,17 +11,24 @@
     apply_rotary_emb = None
 
 
+def is_hopper():
+    # Only Hopper supports a V head dim different from the QK head dim
+    return torch.cuda.get_device_properties(0).major >= 9
+
+
 def is_fa3_supported(device=None) -> bool:
-    # FA3 can fail without a enough shared memory for a some shapes, currently
-    # only 8.0 and 8.7 have enough shared memory for all shapes
+    # FA3 FYI:
+    # FA3 can fail when there is not enough shared memory for some shapes, such as
+    # a large hidden_dim or other special cases.
+    # Right now, FA3 supports sm80/sm87 as well as sm86/sm89. The main difference
+    # between sm80/sm87 and sm86/sm89 is the shared memory size; see the link below:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
+    # For sgl-kernel right now, FA3 can be built for sm80/sm86/sm89/sm90a.
+    # That means FA3 is usable on A100/A*0/L20/L40/L40s/4090, for example.
     return (
-        (torch.cuda.get_device_capability(device)[0] == 9)
-        and (torch.version.cuda >= "12.4")
-        # or torch.cuda.get_device_capability(device) == (8, 0)
-        # or torch.cuda.get_device_capability(device) == (8, 7)
-    )
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    ) and (torch.version.cuda >= "12.3")
 
 
 DISABLE_BACKWARD = True
@@ -558,7 +565,8 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
-    if dtype == torch.float8_e4m3fn:
+    if dtype == torch.float8_e4m3fn or not is_hopper():
+        # For FP8 and non-Hopper archs (e.g. Ampere), V head dim != QK head dim is not supported
         dv_vals = [d]
     for dv in dv_vals:
         has_qv = d == 64 and dv >= 256
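For reviewers who want to sanity-check the new gate on their own machine, here is a minimal standalone sketch (not part of the patch) that mirrors the capability check the updated `is_fa3_supported` performs; it assumes a CUDA-enabled PyTorch build and reuses the patch's lexicographic comparison of `torch.version.cuda`.

```python
import torch

# Mirror of the patched is_fa3_supported() gate: FA3 is considered available on
# sm8x (Ampere/Ada) and sm9x (Hopper) devices when CUDA >= 12.3.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)  # e.g. (8, 0) on A100, (9, 0) on H100
    fa3_ok = major in (8, 9) and torch.version.cuda >= "12.3"
    print(
        f"{torch.cuda.get_device_name(0)}: sm{major}{minor}, "
        f"CUDA {torch.version.cuda}, FA3 supported: {fa3_ok}"
    )
else:
    print("No CUDA device visible; FA3 unavailable.")
```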