From 2112e0e58d99620a00f9bffc651061f23220fab4 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 14 Apr 2025 22:06:33 +0800 Subject: [PATCH 1/3] kernel: support merge_attn_states cuda kernel --- sgl-kernel/CMakeLists.txt | 15 +- .../csrc/attention/merge_attn_states.cu | 191 +++++++++ sgl-kernel/csrc/common_extension.cc | 4 + sgl-kernel/include/sgl_kernel_ops.h | 7 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/attention.py | 41 +- sgl-kernel/tests/test_merge_attn_states.py | 404 ++++++++++++++++++ test/srt/parse_results.py | 5 +- 8 files changed, 655 insertions(+), 13 deletions(-) create mode 100644 sgl-kernel/csrc/attention/merge_attn_states.cu create mode 100644 sgl-kernel/tests/test_merge_attn_states.py diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index ab0b4853f28..b73779d1124 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -91,10 +91,10 @@ set(SGL_KERNEL_CUDA_FLAGS "-O3" "-Xcompiler" "-fPIC" - "-gencode=arch=compute_75,code=sm_75" - "-gencode=arch=compute_80,code=sm_80" + # "-gencode=arch=compute_75,code=sm_75" + # "-gencode=arch=compute_80,code=sm_80" "-gencode=arch=compute_89,code=sm_89" - "-gencode=arch=compute_90,code=sm_90" + # "-gencode=arch=compute_90,code=sm_90" "-std=c++17" "-DFLASHINFER_ENABLE_F16" "-DCUTE_USE_PACKED_TUPLE=1" @@ -130,10 +130,10 @@ else() endif() if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A) - set(SGL_KERNEL_ENABLE_FA3 ON) - list(APPEND SGL_KERNEL_CUDA_FLAGS - "-gencode=arch=compute_90a,code=sm_90a" - ) + # set(SGL_KERNEL_ENABLE_FA3 ON) + # list(APPEND SGL_KERNEL_CUDA_FLAGS + # "-gencode=arch=compute_90a,code=sm_90a" + # ) endif() if (SGL_KERNEL_ENABLE_BF16) @@ -164,6 +164,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE set(SOURCES "csrc/allreduce/custom_all_reduce.cu" "csrc/attention/cascade.cu" + "csrc/attention/merge_attn_states.cu" "csrc/attention/cutlass_mla_kernel.cu" "csrc/attention/lightning_attention_decode_kernel.cu" "csrc/elementwise/activation.cu" diff --git a/sgl-kernel/csrc/attention/merge_attn_states.cu b/sgl-kernel/csrc/attention/merge_attn_states.cu new file mode 100644 index 00000000000..65bd0682cd4 --- /dev/null +++ b/sgl-kernel/csrc/attention/merge_attn_states.cu @@ -0,0 +1,191 @@ +// Adapted from https://github.com/vllm-project/vllm/pull/16173 +#include +#include + +#include +#include + +#include "pytorch_extension_utils.h" + +// Helper functions to convert between different data types +// (float, half, bfloat16) for the merge attention states kernel. 
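+// They let the merge kernel do its accumulation in fp32 regardless of the
+// storage dtype of the attention outputs.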
+inline __device__ float to_float(float u) { + return u; +} +inline __device__ float to_float(half u) { + return __half2float(u); +} +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} +inline __device__ void from_float(float& d, float s) { + d = s; +} +inline __device__ void from_float(half& d, float s) { + d = __float2half(s); +} +inline __device__ void from_float(__nv_bfloat16& d, float s) { + d = __float2bfloat16(s); +} + +// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +// can be used to combine partial attention results (in the split-KV case) +// Reference blog: https://zhuanlan.zhihu.com/p/1892966682634473987 +template +__global__ void merge_attn_states_kernel( + scalar_t* output, + float* output_lse, + const scalar_t* prefix_output, + const float* prefix_lse, + const scalar_t* suffix_output, + const float* suffix_lse, + const uint num_tokens, + const uint num_heads, + const uint head_size) { + using pack_128b_t = uint4; + const uint pack_size = 16 / sizeof(scalar_t); + const uint threads_per_head = head_size / pack_size; + + const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x; + const uint token_head_threads = num_tokens * num_heads * threads_per_head; + + if (global_idx >= token_head_threads) return; + + // global_idx -> token_idx + head_idx + pack_idx + const uint token_head_idx = global_idx / threads_per_head; + const uint pack_idx = global_idx % threads_per_head; + + const uint token_idx = token_head_idx / num_heads; + const uint head_idx = token_head_idx % num_heads; + + const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc. + const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size; + const scalar_t* prefix_head_ptr = prefix_output + head_offset; + const scalar_t* suffix_head_ptr = suffix_output + head_offset; + scalar_t* output_head_ptr = output + head_offset; + + float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; + float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; + p_lse = std::isinf(p_lse) ? -std::numeric_limits::infinity() : p_lse; + s_lse = std::isinf(s_lse) ? -std::numeric_limits::infinity() : s_lse; + + const float max_lse = fmaxf(p_lse, s_lse); + p_lse = p_lse - max_lse; + s_lse = s_lse - max_lse; + const float p_se = expf(p_lse); + const float s_se = expf(s_lse); + const float out_se = p_se + s_se; + const float p_scale = p_se / out_se; + const float s_scale = s_se / out_se; + + if (pack_offset < head_size) { + // Pack 128b load + pack_128b_t p_out_pack = reinterpret_cast(prefix_head_ptr)[pack_offset / pack_size]; + pack_128b_t s_out_pack = reinterpret_cast(suffix_head_ptr)[pack_offset / pack_size]; + pack_128b_t o_out_pack; + +#pragma unroll + for (uint i = 0; i < pack_size; ++i) { + // Always use float for FMA to keep high precision. + // half(uint16_t), bfloat16, float -> float. + const float p_out_f = to_float(reinterpret_cast(&p_out_pack)[i]); + const float s_out_f = to_float(reinterpret_cast(&s_out_pack)[i]); + // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale) + const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale); + // float -> half(uint16_t), bfloat16, float. + from_float(reinterpret_cast(&o_out_pack)[i], o_out_f); + } + + // Pack 128b storage + reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = o_out_pack; + } + // We only need to write to output_lse once per head. 
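+  // out_se holds exp(p_lse - max_lse) + exp(s_lse - max_lse), so the store
+  // below writes log(out_se) + max_lse, i.e. the numerically stable
+  // log-sum-exp of the two partial LSE values.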
+ if (output_lse != nullptr && pack_idx == 0) { + float out_lse = logf(out_se) + max_lse; + output_lse[head_idx * num_tokens + token_idx] = out_lse; + } +} + +// The following macro is used to dispatch the conversion function based on +// the output data type. The FN is a macro that calls a function with +// template. +#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \ + { \ + if (scalar_dtype == at::ScalarType::Float) { \ + fn(float); \ + } else if (scalar_dtype == at::ScalarType::Half) { \ + fn(half); \ + } else if (scalar_dtype == at::ScalarType::BFloat16) { \ + fn(__nv_bfloat16); \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \ + } \ + } + +#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ + { \ + merge_attn_states_kernel<<>>( \ + reinterpret_cast(output.data_ptr()), \ + output_lse_ptr, \ + reinterpret_cast(prefix_output.data_ptr()), \ + reinterpret_cast(prefix_lse.data_ptr()), \ + reinterpret_cast(suffix_output.data_ptr()), \ + reinterpret_cast(suffix_lse.data_ptr()), \ + num_tokens, \ + num_heads, \ + head_size); \ + } + +/*@brief Merges the attention states from prefix and suffix + * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d + * + * @param output [n,h,d] The output tensor to store the merged attention states. + * @param output_lse [h,d] Optional tensor to store the log-sum-exp values. + * @param prefix_output [n,h,d] The prefix attention states. + * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention + * states. + * @param suffix_output [n,h,d] The suffix attention states. + * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention + * states. + */ +template +void merge_attn_states_launcher( + at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + std::optional output_lse, // [NUM_HEADS, NUM_TOKENS] + const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + const at::Tensor& prefix_lse, // [NUM_HEADS, NUM_TOKENS] + const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + const at::Tensor& suffix_lse // [NUM_HEADS, NUM_TOKENS] +) { + constexpr uint NUM_THREADS = 128; + const uint num_tokens = output.size(0); + const uint num_heads = output.size(1); + const uint head_size = output.size(2); + const uint pack_size = 16 / sizeof(scalar_t); + TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); + float* output_lse_ptr = nullptr; + if (output_lse.has_value()) { + output_lse_ptr = output_lse.value().data_ptr(); + } + // process one pack elements per thread. 
float -> 4, half/bf16 -> 8 + const uint threads_per_head = head_size / pack_size; + const uint total_threads = num_tokens * num_heads * threads_per_head; + + dim3 block(NUM_THREADS); + dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS); + + LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); +} + +#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ + { merge_attn_states_launcher(output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse); } + +void merge_attn_states( + at::Tensor& output, + std::optional output_lse, + const at::Tensor& prefix_output, + const at::Tensor& prefix_lse, + const at::Tensor& suffix_output, + const at::Tensor& suffix_lse) { + DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER); +} diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index a8370d89311..fccd14aebf6 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -47,6 +47,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()"); m.impl("merge_state", torch::kCUDA, &merge_state); + m.def( + "merge_attn_states(Tensor! output, Tensor!? output_lse, Tensor prefix_output, Tensor prefix_lse, Tensor " + "suffix_output, Tensor suffix_lse) -> ()"); + m.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); m.def( "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor " "page_table, Tensor workspace) -> ()"); diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 64e530295e7..f2090d9b337 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -89,6 +89,13 @@ void lightning_attention_decode( torch::Tensor new_kv); void merge_state( at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged); +void merge_attn_states( + at::Tensor& output, + std::optional output_lse, + const at::Tensor& prefix_output, + const at::Tensor& prefix_lse, + const at::Tensor& suffix_output, + const at::Tensor& suffix_lse); void cutlass_mla_decode( torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index f8a5b35e806..7845b8fede9 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -15,6 +15,7 @@ cutlass_mla_decode, cutlass_mla_get_workspace_size, lightning_attention_decode, + merge_attn_states, merge_state, ) from sgl_kernel.elementwise import ( diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index b8d6bce75ee..40bc394f154 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch @@ -10,16 +10,49 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): def merge_state( - v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor + v_a: torch.Tensor, + s_a: torch.Tensor, + v_b: torch.Tensor, + s_b: torch.Tensor, + v_merged: Optional[torch.Tensor] = None, + s_merged: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: s_a = s_a.to(torch.float32) s_b = s_b.to(torch.float32) - v_merged = 
torch.empty_like(v_a) - s_merged = torch.empty_like(s_a) + # Avoid creating new tensors if they are already provided + if v_merged is None: + v_merged = torch.empty_like(v_a) + if s_merged is None: + s_merged = torch.empty_like(s_a) torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged) return v_merged, s_merged +def merge_attn_states( + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output: Optional[torch.Tensor] = None, + output_lse: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compare with merge_state function: + # prefix_output: v_a, prefix_lse: s_a + # suffix_output: v_b, suffix_lse: s_b + # output: v_merged, output_lse: s_merged + # TODO(DefTruth): Currently, the custom merge_attn_states kernel + # does not support the FP8 data type and non - CUDA devices. + # It may be necessary to fall back to using the Triton kernel. + if output is None: + output = torch.empty_like(prefix_output) + if output_lse is None: + output_lse = torch.empty_like(prefix_lse) + torch.ops.sgl_kernel.merge_attn_states.default( + output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse + ) + return output, output_lse + + def cutlass_mla_decode( q_nope_and_q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, diff --git a/sgl-kernel/tests/test_merge_attn_states.py b/sgl-kernel/tests/test_merge_attn_states.py new file mode 100644 index 00000000000..897d2956305 --- /dev/null +++ b/sgl-kernel/tests/test_merge_attn_states.py @@ -0,0 +1,404 @@ +from typing import Optional + +import pytest +import torch +import triton +import triton.language as tl +from sgl_kernel import merge_attn_states as merge_attn_states_cuda_v2 +from sgl_kernel import merge_state as merge_attn_states_cuda_v1 + + +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_merge_attn_states.py +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged + output_lse, # [NUM_HEADS, NUM_TOKENS] s_merged + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a + prefix_lse, # [NUM_HEADS, NUM_TOKENS] s_a + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b + suffix_lse, # [NUM_HEADS, NUM_TOKENS] s_b + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + p_lse = float("-inf") if p_lse == float("inf") else p_lse + s_lse = float("-inf") if s_lse == float("inf") else s_lse + + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + out_se = tl.exp(p_lse) + tl.exp(s_lse) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load( + prefix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) + s_out = tl.load( + suffix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) + + p_scale = tl.exp(p_lse) / 
out_se + s_scale = tl.exp(s_lse) / out_se + out = p_out * p_scale + s_out * s_scale + tl.store( + output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask, + ) + + +def merge_attn_states_triton( + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output: Optional[torch.Tensor] = None, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + # Avoid creating new tensors if they are already provided + if output is None: + output = torch.empty_like(prefix_output) + if output_lse is None: + output_lse = torch.empty_like(prefix_lse) + + merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + return output, output_lse + + +# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states_torch( + prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + output: Optional[torch.Tensor] = None, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] +): + # Avoid creating new tensors if they are already provided + if output is None: + output = torch.empty_like(prefix_output) + if output_lse is None: + output_lse = torch.empty_like(prefix_lse) + p_lse = prefix_lse + s_lse = suffix_lse + # inf -> -inf + p_lse[p_lse == torch.inf] = -torch.inf + s_lse[s_lse == torch.inf] = -torch.inf + # max_lse [NUM_HEADS, NUM_TOKENS] + max_lse = torch.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + p_lse_exp = torch.exp(p_lse) + s_lse_exp = torch.exp(s_lse) + out_se = p_lse_exp + s_lse_exp + if output_lse is not None: + output_lse = torch.log(out_se) + max_lse + p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] + s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] + p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + output = prefix_output * p_scale + suffix_output * s_scale + return output, output_lse + + +NUM_BATCH_TOKENS = [256, 512, 613, 1536, 1724, 4096] +NUM_QUERY_HEADS = [8, 16, 32] +HEAD_SIZES = [128] +DTYPES = [torch.half, torch.bfloat16] + +OUTPUT_LSE = True +all_case_info: list[tuple] = [] + + +def generate_markdown_table(): + global all_case_info, OUTPUT_LSE + table_header = ( + "| tokens | heads | headsize | dtype " + "| device | torch | triton | cuda v1 | cuda v2 | speedup(vs triton) | speedup(vs v1)|" + ) + table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |" + + def shortly_dtype(dtype: torch.dtype) -> str: + return str(dtype).removeprefix("torch.") + + def shortly_device(device: str) -> str: + return device.removeprefix("NVIDIA").strip() + + print(table_header) + print(table_separator) + for info in all_case_info: + ( + num_tokens, + num_heads, + head_size, + dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_v1_kernel, + 
avg_time_cuda_v2_kernel, + ) = info + dtype = shortly_dtype(dtype) + device = shortly_device(device) + performance_improved_triton = avg_time_triton_kernel / avg_time_cuda_v2_kernel + performance_improved_cuda_v1 = avg_time_cuda_v1_kernel / avg_time_cuda_v2_kernel + print( + f"| {num_tokens} | {num_heads} | {head_size} " + f"| {dtype} | {device} | {avg_time_torch_kernel:.4f}ms " + f"| {avg_time_triton_kernel:.4f}ms " + f"| {avg_time_cuda_v1_kernel:.4f}ms " + f"| {avg_time_cuda_v2_kernel:.4f}ms " + f"| {performance_improved_triton:.4f}x " + f"| {performance_improved_cuda_v1:.4f}x |" + ) + print(f"\nOUTPUT_LSE: {OUTPUT_LSE}") + + +@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) +@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("output_dtype", DTYPES) +@torch.inference_mode() +def test_merge_attn_states( + num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype +): + if not torch.cuda.is_available(): + pytest.skip( + "Currently only support compare triton merge_attn_states " + "with custom cuda merge_attn_states kernel" + ) + + NUM_TOKENS = num_tokens + NUM_HEADS = num_query_heads + HEAD_SIZE = head_size + + print( + f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " + f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"Device: {torch.cuda.get_device_name()}" + ) + + # prefix_lse and suffix_lse contain inf and normal values + prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") + suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") + + # Generate boolean masks + mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 + mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 + # Ensure that the same position is not True at the same time + combined_mask = torch.logical_and(mask_prefix, mask_suffix) + mask_prefix = torch.logical_and(mask_prefix, ~combined_mask) + mask_suffix = torch.logical_and(mask_suffix, ~combined_mask) + + prefix_lse[mask_prefix] = float("inf") + suffix_lse[mask_suffix] = float("inf") + + # Other input tensors (need to be initialized but + # no actual calculation needed) + output = torch.zeros( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + output_lse = torch.zeros( + (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda" + ) + prefix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + suffix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + + warmup_times = 2 + repeat_times = 20 + + def perf_kernel_fn( + output_fn: torch.Tensor, + output_lse_fn: torch.Tensor, + kernel_fn: callable, + fn_type: str = "torch", + ): + # TODO: align CUDA v1 and CUDA v2 API + total_time = 0 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + # Avoid inplace inf -> -inf, we have to use prefix_lse + # and suffix_lse for other kernel. 
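+        # (The naive torch reference flips inf -> -inf on its lse inputs
+        # in place, so it gets clones; the triton and cuda kernels only
+        # read prefix_lse / suffix_lse.)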
+ if fn_type == "torch": + prefix_lse_ = prefix_lse.clone() + suffix_lse_ = suffix_lse.clone() + else: + prefix_lse_ = prefix_lse + suffix_lse_ = suffix_lse + + if fn_type == "cuda_v1": + prefix_lse_ = prefix_lse_.transpose(0, 1).contiguous() + suffix_lse_ = suffix_lse_.transpose(0, 1).contiguous() + output_lse_fn = output_lse_fn.transpose(0, 1).contiguous() + if output_dtype not in (torch.half, torch.bfloat16): + return 0, output_fn, output_lse_fn + + for _ in range(warmup_times): + output_fn, output_lse_fn = kernel_fn( + prefix_output, + prefix_lse_, + suffix_output, + suffix_lse_, + output_fn, + output_lse_fn, + ) + torch.cuda.synchronize() + + for _ in range(repeat_times): + start.record() + output_fn, output_lse_fn = kernel_fn( + prefix_output, + prefix_lse_, + suffix_output, + suffix_lse_, + output_fn, + output_lse_fn, + ) + end.record() + torch.cuda.synchronize() + total_time += start.elapsed_time(end) + + avg_time = total_time / repeat_times + return avg_time, output_fn, output_lse_fn + + # 0. Run the Torch kernel + output_torch = output.clone() + output_lse_torch = output_lse.clone() if OUTPUT_LSE else None + avg_time_torch_kernel, output_torch, output_lse_torch = perf_kernel_fn( + output_torch, output_lse_torch, merge_attn_states_torch, fn_type="torch" + ) + + # 1. Run the Triton kernel + output_ref_triton = output.clone() + output_lse_ref_triton = output_lse.clone() if OUTPUT_LSE else None + avg_time_triton_kernel, output_ref_triton, output_lse_ref_triton = perf_kernel_fn( + output_ref_triton, + output_lse_ref_triton, + merge_attn_states_triton, + fn_type="triton", + ) + + # 2. Run the CUDA V1 kernel + output_cuda_v1 = output.clone() + output_lse_cuda_v1 = output_lse.clone() if OUTPUT_LSE else None + avg_time_cuda_v1_kernel, output_cuda_v1, output_lse_cuda_v1 = perf_kernel_fn( + output_cuda_v1, output_lse_cuda_v1, merge_attn_states_cuda_v1, fn_type="cuda_v1" + ) + + # 3. Run the CUDA V2 kernel + output_cuda_v2 = output.clone() + output_lse_cuda_v2 = output_lse.clone() if OUTPUT_LSE else None + avg_time_cuda_v2_kernel, output_cuda_v2, output_lse_cuda_v2 = perf_kernel_fn( + output_cuda_v2, output_lse_cuda_v2, merge_attn_states_cuda_v2, fn_type="cuda_v2" + ) + + # 4. Performance compare + performance_improved = avg_time_triton_kernel / avg_time_cuda_v2_kernel + print(f" Torch time: {avg_time_torch_kernel:.6f}ms") + print(f"Triton time: {avg_time_triton_kernel:.6f}ms") + print(f"CUDA v1 time: {avg_time_cuda_v1_kernel:.6f}ms") + print( + f"CUDA v2 time: {avg_time_cuda_v2_kernel:.6f}ms, " + f"Performance: {performance_improved:.5f}x" + ) + print("-" * 100) + + # 5. Correctness compare + # Liger Kernel: Efficient Triton Kernels for LLM Training + # https://arxiv.org/pdf/2410.10989, 3.3 Correctness + # use rtol = 1e-2 for bfloat16. + rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3 + + def diff(a: torch.Tensor, b: torch.Tensor): + max_diff = torch.max(torch.abs(a.float() - b.float())) + return max_diff + + # Use Triton output as reference because we want to replace + # the Triton kernel with custom CUDA kernel for merge attn + # states operation. 
+ output_ref = output_ref_triton + output_lse_ref = output_lse_ref_triton if OUTPUT_LSE else None + torch.testing.assert_close( + output_cuda_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol + ) + print("Output all match, max abs diff:") + print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") + print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_cuda_v2)}") + print(f"(CUDA v2 vs Triton): {diff(output_ref, output_cuda_v2)}") + print("-" * 100) + + if OUTPUT_LSE: + torch.testing.assert_close( + output_lse_cuda_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol + ) + print("Output LSE all match, max abs diff:") + print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") + print(f"(CUDA vs v2 Torch) : {diff(output_lse_torch, output_lse_cuda_v2)}") + print(f"(CUDA vs v2 Triton): {diff(output_lse_ref, output_lse_cuda_v2)}") + print("-" * 100) + + print( + "All output values test passed! All inf values " + "are correctly replaced with -inf." + ) + print("-" * 100) + + device = torch.cuda.get_device_name() + all_case_info.append( + ( + NUM_TOKENS, + NUM_HEADS, + HEAD_SIZE, + output_dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_v1_kernel, + avg_time_cuda_v2_kernel, + ) + ) + if len(all_case_info) == ( + len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES) + ): + generate_markdown_table() diff --git a/test/srt/parse_results.py b/test/srt/parse_results.py index 8389a4b9c2e..de1d5cf2740 100644 --- a/test/srt/parse_results.py +++ b/test/srt/parse_results.py @@ -1,7 +1,8 @@ -import json -import pandas as pd import argparse +import json import os + +import pandas as pd from tabulate import tabulate # Parse command-line arguments From 59142544d3f049df49072e72fb6bf512e79b093d Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 14 Apr 2025 22:09:38 +0800 Subject: [PATCH 2/3] kernel: support merge_attn_states cuda kernel --- sgl-kernel/tests/test_merge_attn_states.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/tests/test_merge_attn_states.py b/sgl-kernel/tests/test_merge_attn_states.py index 897d2956305..f2f619c57f8 100644 --- a/sgl-kernel/tests/test_merge_attn_states.py +++ b/sgl-kernel/tests/test_merge_attn_states.py @@ -153,7 +153,9 @@ def generate_markdown_table(): "| tokens | heads | headsize | dtype " "| device | torch | triton | cuda v1 | cuda v2 | speedup(vs triton) | speedup(vs v1)|" ) - table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |" + table_separator = ( + "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |" + ) def shortly_dtype(dtype: torch.dtype) -> str: return str(dtype).removeprefix("torch.") From c5a28f6e8c4f0196277b7a50971548f7d72307ea Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 15 Apr 2025 11:00:28 +0800 Subject: [PATCH 3/3] kernel: support merge_state_v2 cuda kernel --- sgl-kernel/CMakeLists.txt | 14 +- .../csrc/attention/merge_attn_states.cu | 70 +++--- sgl-kernel/csrc/common_extension.cc | 6 +- sgl-kernel/include/sgl_kernel_ops.h | 9 +- sgl-kernel/python/sgl_kernel/__init__.py | 2 +- sgl-kernel/python/sgl_kernel/attention.py | 36 ++- ..._attn_states.py => test_merge_state_v2.py} | 230 +++++++++--------- 7 files changed, 179 insertions(+), 188 deletions(-) rename sgl-kernel/tests/{test_merge_attn_states.py => test_merge_state_v2.py} (58%) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index b73779d1124..fc808d11d7c 100644 --- a/sgl-kernel/CMakeLists.txt +++ 
b/sgl-kernel/CMakeLists.txt @@ -91,10 +91,10 @@ set(SGL_KERNEL_CUDA_FLAGS "-O3" "-Xcompiler" "-fPIC" - # "-gencode=arch=compute_75,code=sm_75" - # "-gencode=arch=compute_80,code=sm_80" + "-gencode=arch=compute_75,code=sm_75" + "-gencode=arch=compute_80,code=sm_80" "-gencode=arch=compute_89,code=sm_89" - # "-gencode=arch=compute_90,code=sm_90" + "-gencode=arch=compute_90,code=sm_90" "-std=c++17" "-DFLASHINFER_ENABLE_F16" "-DCUTE_USE_PACKED_TUPLE=1" @@ -130,10 +130,10 @@ else() endif() if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A) - # set(SGL_KERNEL_ENABLE_FA3 ON) - # list(APPEND SGL_KERNEL_CUDA_FLAGS - # "-gencode=arch=compute_90a,code=sm_90a" - # ) + set(SGL_KERNEL_ENABLE_FA3 ON) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_90a,code=sm_90a" + ) endif() if (SGL_KERNEL_ENABLE_BF16) diff --git a/sgl-kernel/csrc/attention/merge_attn_states.cu b/sgl-kernel/csrc/attention/merge_attn_states.cu index 65bd0682cd4..a3b40534008 100644 --- a/sgl-kernel/csrc/attention/merge_attn_states.cu +++ b/sgl-kernel/csrc/attention/merge_attn_states.cu @@ -1,4 +1,3 @@ -// Adapted from https://github.com/vllm-project/vllm/pull/16173 #include #include @@ -29,8 +28,6 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) { } // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 -// can be used to combine partial attention results (in the split-KV case) -// Reference blog: https://zhuanlan.zhihu.com/p/1892966682634473987 template __global__ void merge_attn_states_kernel( scalar_t* output, @@ -64,8 +61,10 @@ __global__ void merge_attn_states_kernel( const scalar_t* suffix_head_ptr = suffix_output + head_offset; scalar_t* output_head_ptr = output + head_offset; - float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; - float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; + // float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; + // float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; + float p_lse = prefix_lse[token_idx * num_heads + head_idx]; + float s_lse = suffix_lse[token_idx * num_heads + head_idx]; p_lse = std::isinf(p_lse) ? -std::numeric_limits::infinity() : p_lse; s_lse = std::isinf(s_lse) ? -std::numeric_limits::infinity() : s_lse; @@ -102,7 +101,7 @@ __global__ void merge_attn_states_kernel( // We only need to write to output_lse once per head. if (output_lse != nullptr && pack_idx == 0) { float out_lse = logf(out_se) + max_lse; - output_lse[head_idx * num_tokens + token_idx] = out_lse; + output_lse[token_idx * num_heads + head_idx] = out_lse; } } @@ -126,7 +125,7 @@ __global__ void merge_attn_states_kernel( { \ merge_attn_states_kernel<<>>( \ reinterpret_cast(output.data_ptr()), \ - output_lse_ptr, \ + reinterpret_cast(output_lse.data_ptr()), \ reinterpret_cast(prefix_output.data_ptr()), \ reinterpret_cast(prefix_lse.data_ptr()), \ reinterpret_cast(suffix_output.data_ptr()), \ @@ -142,20 +141,20 @@ __global__ void merge_attn_states_kernel( * @param output [n,h,d] The output tensor to store the merged attention states. * @param output_lse [h,d] Optional tensor to store the log-sum-exp values. * @param prefix_output [n,h,d] The prefix attention states. - * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention + * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention * states. * @param suffix_output [n,h,d] The suffix attention states. 
- * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention + * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention * states. */ template void merge_attn_states_launcher( - at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - std::optional output_lse, // [NUM_HEADS, NUM_TOKENS] - const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - const at::Tensor& prefix_lse, // [NUM_HEADS, NUM_TOKENS] - const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - const at::Tensor& suffix_lse // [NUM_HEADS, NUM_TOKENS] + const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + const at::Tensor& prefix_lse, // [NUM_TOKENS, NUM_HEADS] + const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + const at::Tensor& suffix_lse, // [NUM_TOKENS, NUM_HEADS] + at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + at::Tensor& output_lse // [NUM_TOKENS, NUM_HEADS] ) { constexpr uint NUM_THREADS = 128; const uint num_tokens = output.size(0); @@ -163,11 +162,8 @@ void merge_attn_states_launcher( const uint head_size = output.size(2); const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); - float* output_lse_ptr = nullptr; - if (output_lse.has_value()) { - output_lse_ptr = output_lse.value().data_ptr(); - } - // process one pack elements per thread. float -> 4, half/bf16 -> 8 + // Process one pack elements per thread. for float, the + // pack_size is 4 for half/bf16, the pack_size is 8. const uint threads_per_head = head_size / pack_size; const uint total_threads = num_tokens * num_heads * threads_per_head; @@ -178,14 +174,28 @@ void merge_attn_states_launcher( } #define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ - { merge_attn_states_launcher(output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse); } - -void merge_attn_states( - at::Tensor& output, - std::optional output_lse, - const at::Tensor& prefix_output, - const at::Tensor& prefix_lse, - const at::Tensor& suffix_output, - const at::Tensor& suffix_lse) { - DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER); + { merge_attn_states_launcher(v_a, s_a, v_b, s_b, v_merged, s_merged); } + +void merge_state_v2( + at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) { + // Input tensors must be contiguous + CHECK_INPUT(v_a); // v_a prefix_output (seq_len, num_heads, head_dim) + CHECK_INPUT(s_a); // s_a prefix_lse (seq_len, num_heads) + CHECK_INPUT(v_b); // v_b suffix_output (seq_len, num_heads, head_dim) + CHECK_INPUT(s_b); // s_b suffix_lse (seq_len, num_heads) + // v_merged output (seq_len, num_heads, head_dim) + // s_merged output_lse (seq_len, num_heads) + auto device = v_a.device(); + CHECK_EQ(s_a.device(), device); + CHECK_EQ(v_b.device(), device); + CHECK_EQ(s_b.device(), device); + CHECK_DIM(3, v_a); + CHECK_DIM(2, s_a); + CHECK_DIM(3, v_b); + CHECK_DIM(2, s_b); + CHECK_SHAPE(v_a, v_b); + CHECK_SHAPE(s_a, s_b); + CHECK_EQ(v_a.size(0), s_a.size(0)); + CHECK_EQ(v_a.size(1), s_b.size(1)); + DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER); } diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index fccd14aebf6..d3e0ffae82b 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -47,10 +47,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("lightning_attention_decode", torch::kCUDA, 
&lightning_attention_decode); m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()"); m.impl("merge_state", torch::kCUDA, &merge_state); - m.def( - "merge_attn_states(Tensor! output, Tensor!? output_lse, Tensor prefix_output, Tensor prefix_lse, Tensor " - "suffix_output, Tensor suffix_lse) -> ()"); - m.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); + m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()"); + m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2); m.def( "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor " "page_table, Tensor workspace) -> ()"); diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index f2090d9b337..118a8ba058e 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -89,13 +89,8 @@ void lightning_attention_decode( torch::Tensor new_kv); void merge_state( at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged); -void merge_attn_states( - at::Tensor& output, - std::optional output_lse, - const at::Tensor& prefix_output, - const at::Tensor& prefix_lse, - const at::Tensor& suffix_output, - const at::Tensor& suffix_lse); +void merge_state_v2( + at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged); void cutlass_mla_decode( torch::Tensor const& out, torch::Tensor const& q_nope_and_q_pe, diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 7845b8fede9..a6338ee5aad 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -15,8 +15,8 @@ cutlass_mla_decode, cutlass_mla_get_workspace_size, lightning_attention_decode, - merge_attn_states, merge_state, + merge_state_v2, ) from sgl_kernel.elementwise import ( apply_rope_with_cos_sin_cache_inplace, diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index 40bc394f154..d80a6fbbd52 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -28,29 +28,27 @@ def merge_state( return v_merged, s_merged -def merge_attn_states( - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, - output: Optional[torch.Tensor] = None, - output_lse: Optional[torch.Tensor] = None, +def merge_state_v2( + v_a: torch.Tensor, + s_a: torch.Tensor, + v_b: torch.Tensor, + s_b: torch.Tensor, + v_merged: Optional[torch.Tensor] = None, + s_merged: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Compare with merge_state function: - # prefix_output: v_a, prefix_lse: s_a - # suffix_output: v_b, suffix_lse: s_b - # output: v_merged, output_lse: s_merged + s_a = s_a.to(torch.float32) + s_b = s_b.to(torch.float32) # TODO(DefTruth): Currently, the custom merge_attn_states kernel # does not support the FP8 data type and non - CUDA devices. # It may be necessary to fall back to using the Triton kernel. 
- if output is None: - output = torch.empty_like(prefix_output) - if output_lse is None: - output_lse = torch.empty_like(prefix_lse) - torch.ops.sgl_kernel.merge_attn_states.default( - output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse - ) - return output, output_lse + + # Avoid creating new tensors if they are already provided + if v_merged is None: + v_merged = torch.empty_like(v_a) + if s_merged is None: + s_merged = torch.empty_like(s_a) + torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged) + return v_merged, s_merged def cutlass_mla_decode( diff --git a/sgl-kernel/tests/test_merge_attn_states.py b/sgl-kernel/tests/test_merge_state_v2.py similarity index 58% rename from sgl-kernel/tests/test_merge_attn_states.py rename to sgl-kernel/tests/test_merge_state_v2.py index f2f619c57f8..f5c7a30dddb 100644 --- a/sgl-kernel/tests/test_merge_attn_states.py +++ b/sgl-kernel/tests/test_merge_state_v2.py @@ -4,21 +4,17 @@ import torch import triton import triton.language as tl -from sgl_kernel import merge_attn_states as merge_attn_states_cuda_v2 -from sgl_kernel import merge_state as merge_attn_states_cuda_v1 +from sgl_kernel import merge_state, merge_state_v2 -# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 -# can be used to combine partial attention results (in the split-KV case) -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_merge_attn_states.py @triton.jit -def merge_attn_states_kernel( +def merge_state_kernel( output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged - output_lse, # [NUM_HEADS, NUM_TOKENS] s_merged + output_lse, # [NUM_TOKENS, NUM_HEADS] s_merged prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a - prefix_lse, # [NUM_HEADS, NUM_TOKENS] s_a + prefix_lse, # [NUM_TOKENS, NUM_HEADS] s_a suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b - suffix_lse, # [NUM_HEADS, NUM_TOKENS] s_b + suffix_lse, # [NUM_TOKENS, NUM_HEADS] s_b HEAD_SIZE: tl.constexpr, PADDED_HEAD_SIZE: tl.constexpr, OUTPUT_LSE: tl.constexpr, @@ -28,8 +24,8 @@ def merge_attn_states_kernel( head_idx = tl.program_id(1) num_heads = tl.num_programs(1) - p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) - s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx) + s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx) p_lse = float("-inf") if p_lse == float("inf") else p_lse s_lse = float("-inf") if s_lse == float("inf") else s_lse @@ -40,7 +36,7 @@ def merge_attn_states_kernel( if OUTPUT_LSE: out_lse = tl.log(out_se) + max_lse - tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + tl.store(output_lse + token_idx * num_heads + head_idx, out_lse) head_arange = tl.arange(0, PADDED_HEAD_SIZE) head_mask = head_arange < HEAD_SIZE @@ -69,7 +65,7 @@ def merge_attn_states_kernel( ) -def merge_attn_states_triton( +def merge_state_triton( prefix_output: torch.Tensor, prefix_lse: torch.Tensor, suffix_output: torch.Tensor, @@ -87,7 +83,7 @@ def merge_attn_states_triton( if output_lse is None: output_lse = torch.empty_like(prefix_lse) - merge_attn_states_kernel[(num_tokens, num_query_heads)]( + merge_state_kernel[(num_tokens, num_query_heads)]( output, output_lse, prefix_output, @@ -101,15 +97,14 @@ def merge_attn_states_triton( return output, output_lse -# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 -# can be used to combine partial attention results (in the split-KV case) -def 
merge_attn_states_torch( +# Naive PyTorch Implements of Merge Attn States +def merge_state_torch( prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + prefix_lse: torch.Tensor, # [NUM_TOKENS, NUM_HEADS] suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + suffix_lse: torch.Tensor, # [NUM_TOKENS, NUM_HEADS] output: Optional[torch.Tensor] = None, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] + output_lse: Optional[torch.Tensor] = None, # [NUM_TOKENS, NUM_HEADS] ): # Avoid creating new tensors if they are already provided if output is None: @@ -130,28 +125,27 @@ def merge_attn_states_torch( out_se = p_lse_exp + s_lse_exp if output_lse is not None: output_lse = torch.log(out_se) + max_lse - p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] - s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] - p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] - s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + p_scale = p_lse_exp / out_se + s_scale = s_lse_exp / out_se + p_scale = p_scale.unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + s_scale = s_scale.unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] output = prefix_output * p_scale + suffix_output * s_scale return output, output_lse -NUM_BATCH_TOKENS = [256, 512, 613, 1536, 1724, 4096] +NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536] NUM_QUERY_HEADS = [8, 16, 32] -HEAD_SIZES = [128] +HEAD_SIZES = [32, 48, 64, 128, 256] DTYPES = [torch.half, torch.bfloat16] -OUTPUT_LSE = True all_case_info: list[tuple] = [] def generate_markdown_table(): - global all_case_info, OUTPUT_LSE + global all_case_info table_header = ( "| tokens | heads | headsize | dtype " - "| device | torch | triton | cuda v1 | cuda v2 | speedup(vs triton) | speedup(vs v1)|" + "| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|" ) table_separator = ( "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |" @@ -172,25 +166,24 @@ def shortly_device(device: str) -> str: head_size, dtype, device, - avg_time_torch_kernel, - avg_time_triton_kernel, - avg_time_cuda_v1_kernel, - avg_time_cuda_v2_kernel, + time_torch, + time_triton, + time_v1, + time_v2, ) = info dtype = shortly_dtype(dtype) device = shortly_device(device) - performance_improved_triton = avg_time_triton_kernel / avg_time_cuda_v2_kernel - performance_improved_cuda_v1 = avg_time_cuda_v1_kernel / avg_time_cuda_v2_kernel + improved_triton = time_triton / time_v2 + improved_v1 = time_v1 / time_v2 print( f"| {num_tokens} | {num_heads} | {head_size} " - f"| {dtype} | {device} | {avg_time_torch_kernel:.4f}ms " - f"| {avg_time_triton_kernel:.4f}ms " - f"| {avg_time_cuda_v1_kernel:.4f}ms " - f"| {avg_time_cuda_v2_kernel:.4f}ms " - f"| {performance_improved_triton:.4f}x " - f"| {performance_improved_cuda_v1:.4f}x |" + f"| {dtype} | {device} | {time_torch:.4f}ms " + f"| {time_triton:.4f}ms " + f"| {time_v1:.4f}ms " + f"| {time_v2:.4f}ms " + f"| {improved_triton:.4f}x " + f"| {improved_v1:.4f}x |" ) - print(f"\nOUTPUT_LSE: {OUTPUT_LSE}") @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) @@ -218,12 +211,12 @@ def test_merge_attn_states( ) # prefix_lse and suffix_lse contain inf and normal values - prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") - suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, 
device="cuda") + prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda") + suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda") # Generate boolean masks - mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 - mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 + mask_prefix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1 + mask_suffix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1 # Ensure that the same position is not True at the same time combined_mask = torch.logical_and(mask_prefix, mask_suffix) mask_prefix = torch.logical_and(mask_prefix, ~combined_mask) @@ -238,7 +231,7 @@ def test_merge_attn_states( (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" ) output_lse = torch.zeros( - (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda" + (NUM_TOKENS, NUM_HEADS), dtype=torch.float32, device="cuda" ) prefix_output = torch.randn( (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" @@ -256,10 +249,6 @@ def perf_kernel_fn( kernel_fn: callable, fn_type: str = "torch", ): - # TODO: align CUDA v1 and CUDA v2 API - total_time = 0 - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) # Avoid inplace inf -> -inf, we have to use prefix_lse # and suffix_lse for other kernel. if fn_type == "torch": @@ -270,80 +259,82 @@ def perf_kernel_fn( suffix_lse_ = suffix_lse if fn_type == "cuda_v1": - prefix_lse_ = prefix_lse_.transpose(0, 1).contiguous() - suffix_lse_ = suffix_lse_.transpose(0, 1).contiguous() - output_lse_fn = output_lse_fn.transpose(0, 1).contiguous() + # merge_state v1 kernel not support float32 if output_dtype not in (torch.half, torch.bfloat16): return 0, output_fn, output_lse_fn - for _ in range(warmup_times): - output_fn, output_lse_fn = kernel_fn( - prefix_output, - prefix_lse_, - suffix_output, - suffix_lse_, - output_fn, - output_lse_fn, - ) - torch.cuda.synchronize() - - for _ in range(repeat_times): - start.record() - output_fn, output_lse_fn = kernel_fn( - prefix_output, - prefix_lse_, - suffix_output, - suffix_lse_, - output_fn, - output_lse_fn, - ) - end.record() + total_time = 0 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + try: + for _ in range(warmup_times): + output_fn, output_lse_fn = kernel_fn( + prefix_output, + prefix_lse_, + suffix_output, + suffix_lse_, + output_fn, + output_lse_fn, + ) torch.cuda.synchronize() - total_time += start.elapsed_time(end) - avg_time = total_time / repeat_times - return avg_time, output_fn, output_lse_fn + for _ in range(repeat_times): + start.record() + output_fn, output_lse_fn = kernel_fn( + prefix_output, + prefix_lse_, + suffix_output, + suffix_lse_, + output_fn, + output_lse_fn, + ) + end.record() + torch.cuda.synchronize() + total_time += start.elapsed_time(end) + + avg_time = total_time / repeat_times + return avg_time, output_fn, output_lse_fn + except Exception as e: + return 0, output_fn, output_lse_fn # 0. Run the Torch kernel output_torch = output.clone() - output_lse_torch = output_lse.clone() if OUTPUT_LSE else None - avg_time_torch_kernel, output_torch, output_lse_torch = perf_kernel_fn( - output_torch, output_lse_torch, merge_attn_states_torch, fn_type="torch" + output_lse_torch = output_lse.clone() + time_torch, output_torch, output_lse_torch = perf_kernel_fn( + output_torch, output_lse_torch, merge_state_torch, fn_type="torch" ) # 1. 
Run the Triton kernel output_ref_triton = output.clone() - output_lse_ref_triton = output_lse.clone() if OUTPUT_LSE else None - avg_time_triton_kernel, output_ref_triton, output_lse_ref_triton = perf_kernel_fn( + output_lse_ref_triton = output_lse.clone() + time_triton, output_ref_triton, output_lse_ref_triton = perf_kernel_fn( output_ref_triton, output_lse_ref_triton, - merge_attn_states_triton, + merge_state_triton, fn_type="triton", ) - # 2. Run the CUDA V1 kernel - output_cuda_v1 = output.clone() - output_lse_cuda_v1 = output_lse.clone() if OUTPUT_LSE else None - avg_time_cuda_v1_kernel, output_cuda_v1, output_lse_cuda_v1 = perf_kernel_fn( - output_cuda_v1, output_lse_cuda_v1, merge_attn_states_cuda_v1, fn_type="cuda_v1" + # 2. Run the merge_state V1 kernel + output_v1 = output.clone() + output_lse_v1 = output_lse.clone() + time_v1, output_v1, output_lse_v1 = perf_kernel_fn( + output_v1, output_lse_v1, merge_state, fn_type="cuda_v1" ) - # 3. Run the CUDA V2 kernel - output_cuda_v2 = output.clone() - output_lse_cuda_v2 = output_lse.clone() if OUTPUT_LSE else None - avg_time_cuda_v2_kernel, output_cuda_v2, output_lse_cuda_v2 = perf_kernel_fn( - output_cuda_v2, output_lse_cuda_v2, merge_attn_states_cuda_v2, fn_type="cuda_v2" + # 3. Run the merge_state V2 kernel + output_v2 = output.clone() + output_lse_v2 = output_lse.clone() + time_v2, output_v2, output_lse_v2 = perf_kernel_fn( + output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2" ) # 4. Performance compare - performance_improved = avg_time_triton_kernel / avg_time_cuda_v2_kernel - print(f" Torch time: {avg_time_torch_kernel:.6f}ms") - print(f"Triton time: {avg_time_triton_kernel:.6f}ms") - print(f"CUDA v1 time: {avg_time_cuda_v1_kernel:.6f}ms") - print( - f"CUDA v2 time: {avg_time_cuda_v2_kernel:.6f}ms, " - f"Performance: {performance_improved:.5f}x" - ) + improved = time_triton / time_v2 + print(f" Torch time: {time_torch:.6f}ms") + print(f" Triton time: {time_triton:.6f}ms") + print(f"CUDA v1 time: {time_v1:.6f}ms") + print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x") print("-" * 100) # 5. Correctness compare @@ -360,25 +351,24 @@ def diff(a: torch.Tensor, b: torch.Tensor): # the Triton kernel with custom CUDA kernel for merge attn # states operation. 
output_ref = output_ref_triton - output_lse_ref = output_lse_ref_triton if OUTPUT_LSE else None + output_lse_ref = output_lse_ref_triton torch.testing.assert_close( - output_cuda_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol + output_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol ) print("Output all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") - print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_cuda_v2)}") - print(f"(CUDA v2 vs Triton): {diff(output_ref, output_cuda_v2)}") + print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_v2)}") + print(f"(CUDA v2 vs Triton): {diff(output_ref, output_v2)}") print("-" * 100) - if OUTPUT_LSE: - torch.testing.assert_close( - output_lse_cuda_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol - ) - print("Output LSE all match, max abs diff:") - print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") - print(f"(CUDA vs v2 Torch) : {diff(output_lse_torch, output_lse_cuda_v2)}") - print(f"(CUDA vs v2 Triton): {diff(output_lse_ref, output_lse_cuda_v2)}") - print("-" * 100) + torch.testing.assert_close( + output_lse_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol + ) + print("Output LSE all match, max abs diff:") + print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") + print(f"(CUDA v2 vs Torch) : {diff(output_lse_torch, output_lse_v2)}") + print(f"(CUDA v2 vs Triton): {diff(output_lse_ref, output_lse_v2)}") + print("-" * 100) print( "All output values test passed! All inf values " @@ -394,10 +384,10 @@ def diff(a: torch.Tensor, b: torch.Tensor): HEAD_SIZE, output_dtype, device, - avg_time_torch_kernel, - avg_time_triton_kernel, - avg_time_cuda_v1_kernel, - avg_time_cuda_v2_kernel, + time_torch, + time_triton, + time_v1, + time_v2, ) ) if len(all_case_info) == (
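Usage sketch for the final API: a minimal example of calling the merge_state_v2 wrapper added in PATCH 3/3, assuming a CUDA device and the [num_tokens, num_heads(, head_dim)] tensor layout the kernel checks for; the merged output buffers are optional and are allocated by the wrapper when omitted.

    import torch
    from sgl_kernel import merge_state_v2

    num_tokens, num_heads, head_dim = 1024, 16, 128
    # Partial attention outputs from two KV splits (e.g. prefix and suffix).
    v_a = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.half, device="cuda")
    v_b = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.half, device="cuda")
    # Per-(token, head) log-sum-exp values for each split (float32).
    s_a = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
    s_b = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")

    v_merged, s_merged = merge_state_v2(v_a, s_a, v_b, s_b)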