Merged
2 changes: 1 addition & 1 deletion csrc/xqa/mha.cu
@@ -89,7 +89,7 @@ constexpr uint32_t cvtExpansion = exactDiv(inputElemSize, cacheElemSize);
constexpr uint32_t preferedKHeadPartBytes = 64;
__constant__ constexpr uint32_t cacheVTileSeqLen = 32;
#else
-#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200
+#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210
constexpr uint32_t preferedKHeadPartBytes = 64;
__constant__ constexpr uint32_t cacheVTileSeqLen = 32;
#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || \
9 changes: 6 additions & 3 deletions csrc/xqa/mha_sm90.cu
@@ -1966,9 +1966,12 @@ __device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec
for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) {
static_assert(nbThrdsPerInstNBase * RegColWiseVec::size ==
exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
-    ret[i] = reinterpret_cast<Vec<Vec<float, GmmaAccCoreMat::cols>,
-                                  exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
-        gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
+    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
+#pragma unroll
+    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+      ret[i][j] = gmemVec[baseOffset + j];
+    }
}
Comment on lines +1877 to 1883
⚠️ Potential issue | πŸ”΄ Critical

Out‑of‑bounds read in loadGmemColWiseVecWithDup for attention sinks

gmemVec points to a buffer of size headGrpSize (see finalizeAndWriteOut_sync passing attentionSinksVec[0]), but this code multiplies the index by GmmaAccCoreMat::cols and reads baseOffset+j, which can exceed headGrpSize. We should load a single sink value per head and duplicate it across columns, without advancing memory by cols.

Apply this fix:

-    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
-    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
-#pragma unroll
-    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
-      ret[i][j] = gmemVec[baseOffset + j];
-    }
+    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+#pragma unroll
+    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+      // Duplicate the same head sink across the 2 columns
+      ret[i][j] = gmemVec[clampedIdx];
+    }
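To see the overrun concretely, here is a small standalone Python sketch of the two indexing schemes. The sizes, and the assumption that `bound` is the last valid per-head index into the sink buffer, are illustrative rather than taken from the kernel:

# Minimal sketch of the old vs. new indexing (illustrative values only).
head_grp_size = 8            # size of the per-head attention-sink buffer (gmemVec stand-in)
cols = 2                     # stands in for GmmaAccCoreMat::cols
bound = head_grp_size - 1    # assumed: clamp to the last valid per-head index

for clamped_idx in range(bound + 1):
    old_reads = [clamped_idx * cols + j for j in range(cols)]  # pre-fix: advances by cols
    new_reads = [clamped_idx] * cols                           # post-fix: same per-head value, duplicated
    assert max(new_reads) < head_grp_size                      # never leaves the buffer
    if max(old_reads) >= head_grp_size:
        print(f"idx {clamped_idx}: reads {old_reads} overrun a buffer of size {head_grp_size}")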
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
ret[i][j] = gmemVec[baseOffset + j];
}
}
uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
// Duplicate the same head sink across the 2 columns
ret[i][j] = gmemVec[clampedIdx];
}

return ret;
}
2 changes: 1 addition & 1 deletion csrc/xqa/utils.cuh
@@ -42,7 +42,7 @@ inline constexpr int32_t kBAD_PAGE_INDEX = -1;
__constant__ constexpr float kE4M3_MAX = 448.F;

#ifdef __CUDA_ARCH__
-#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200
+#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210
constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10);
#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870
constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10);
2 changes: 2 additions & 0 deletions flashinfer/aot.py
@@ -404,6 +404,7 @@ def gen_xqa(
head_dim=head_size,
head_group_ratio=head_grp_size,
use_sliding_window=use_sliding_window,
enable_pdl=True,
)

if has_sm120 or has_sm121:
@@ -415,6 +416,7 @@
head_dim=576,
head_group_ratio=128,
use_sliding_window=False,
enable_pdl=True,
)


129 changes: 129 additions & 0 deletions flashinfer/decode.py
@@ -21,6 +21,7 @@

import torch

from .xqa import xqa
from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .jit import (
gen_batch_decode_mla_module,
@@ -2253,6 +2254,133 @@ def trtllm_batch_decode_with_kv_cache(
)


# xqa uses NHD layout
def xqa_batch_decode_with_kv_cache(
query: torch.Tensor,
kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
workspace_buffer: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
bmm1_scale: float,
bmm2_scale: float,
window_left: int = -1,
out: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
enable_pdl: Optional[bool] = None,
q_len_per_req: Optional[int] = 1,
) -> torch.Tensor:
"""
Parameters
----------
query : torch.Tensor
query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens = batch_size * q_len_per_req

kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, page_size, num_kv_heads, head_dim]
If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, page_size, num_kv_heads, head_dim]

workspace_buffer : torch.Tensor
workspace buffer; must be initialized to 0 before its first use.

block_tables : torch.Tensor
page_table of kv cache, [batch_size, num_pages]

seq_lens : torch.Tensor
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``

max_seq_len : int
max sequence length for kv_cache

bmm1_scale : float
fused scale for bmm1 input.

bmm2_scale : float
fused scale for bmm2 input.

window_left : int = -1
The left (inclusive) window size for the attention window. When set to ``-1``, the window
size is the full length of the sequence. Defaults to ``-1``.

out : Optional[torch.Tensor] = None
output tensor, if not provided, will be allocated with ``query.dtype``.

sinks : Optional[torch.Tensor] = None
additional value per head in the denominator of the softmax.

enable_pdl : Optional[bool]
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
If ``None``, PDL is enabled automatically when the device supports it.
Only supported for >= sm90, and currently only for FA2, CUDA core, and trtllm-gen decode.

Returns
-------
out : torch.Tensor
output torch.Tensor.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

assert q_len_per_req == 1, "xqa does not support speculative decoding yet"

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
else:
if kv_cache.shape[1] == 1:
k_cache, v_cache = kv_cache, kv_cache
else:
assert kv_cache.shape[1] == 2, (
"When kv_cache is a single tensor, the second dimension must be 1 or 2"
)
# NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
Comment on lines 2400 to 2409
⚠️ Potential issue | πŸ”΄ Critical

Squeeze the singleton KV axis before inferring shapes

When kv_cache comes in as a single tensor with shape [num_pages, 1, page_size, num_kv_heads, head_dim], assigning it directly to k_cache/v_cache preserves that extra dimension. As a result num_kv_heads = k_cache.shape[2] picks up page_size, and page_size = k_cache.shape[1] becomes 1, so every downstream stride and reshape is wrong. Strip the singleton dimension before computing the metadata.

     if isinstance(kv_cache, tuple):
         k_cache, v_cache = kv_cache
     else:
         if kv_cache.shape[1] == 1:
-            k_cache, v_cache = kv_cache, kv_cache
+            k_cache = kv_cache.squeeze(1)
+            v_cache = k_cache
         else:
             assert kv_cache.shape[1] == 2, (

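As a quick standalone illustration of the shape mix-up (made-up sizes, not code from this PR):

import torch

num_pages, page_size, num_kv_heads, head_dim = 4, 16, 2, 64
kv_cache = torch.zeros(num_pages, 1, page_size, num_kv_heads, head_dim)

k_bad = kv_cache                  # singleton axis kept
# shape[1]=1 is read as page_size, shape[2]=16 (the real page_size) as num_kv_heads
print(k_bad.shape[1], k_bad.shape[2], k_bad.shape[3])     # 1 16 2

k_good = kv_cache.squeeze(1)      # singleton axis stripped
print(k_good.shape[1], k_good.shape[2], k_good.shape[3])  # 16 2 64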
# it doesn't change underlying storage
k_cache, v_cache = kv_cache.unbind(dim=1)

sm_count = get_device_sm_count(query.device)

bmm1_scale = (
bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale
)
bmm2_scale = (
bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale
)

num_kv_heads = k_cache.shape[2]
page_size = k_cache.shape[1]
head_dim = k_cache.shape[3]
workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
kv_scale_value = bmm2_scale
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)

k_cache_new = k_cache.reshape(-1, head_dim).contiguous()
v_cache_new = v_cache.reshape(-1, head_dim).contiguous()
query_new = query.unsqueeze(1).contiguous()
seq_lens_new = seq_lens.unsqueeze(1).contiguous()
sinks_new = (
sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
)

xqa(
query_new,
k_cache_new,
v_cache_new,
block_tables,
seq_lens_new,
out,
workspace_0,
workspace_1,
num_kv_heads,
page_size,
sinks=sinks_new,
q_scale=q_scale_value,
kv_scale=torch.tensor(
[kv_scale_value], dtype=torch.float32, device=query.device
),
sliding_win_size=window_left + 1 if window_left >= 0 else 0,
sm_count=sm_count,
)

return out

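For orientation, a minimal call sketch for the new wrapper. All sizes, dtypes, the workspace size, and the head-group ratio below are illustrative assumptions rather than values from this PR, and ``out`` is passed explicitly because the wrapper forwards it to the kernel as-is:

import torch
from flashinfer.decode import xqa_batch_decode_with_kv_cache

batch_size, num_heads, num_kv_heads, head_dim = 2, 32, 8, 128   # illustrative sizes
page_size, pages_per_seq, max_seq_len = 16, 8, 128

query = torch.randn(batch_size, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(
    batch_size * pages_per_seq, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device="cuda",
)
block_tables = torch.arange(
    batch_size * pages_per_seq, dtype=torch.int32, device="cuda"
).reshape(batch_size, pages_per_seq)
seq_lens = torch.full((batch_size,), 64, dtype=torch.int32, device="cuda").to(torch.uint32)  # uint32 per the docstring
workspace = torch.zeros(64 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # must be zeroed before first use
out = torch.empty_like(query)

xqa_batch_decode_with_kv_cache(
    query, kv_cache, workspace, block_tables, seq_lens, max_seq_len,
    bmm1_scale=1.0 / head_dim**0.5, bmm2_scale=1.0, out=out,
)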

def _check_trtllm_gen_mla_shape(
query,
kv_cache,
@@ -2410,6 +2538,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
)

return out


20 changes: 17 additions & 3 deletions flashinfer/jit/xqa.py
@@ -42,6 +42,7 @@ def gen_xqa_module(
head_dim: int,
head_group_ratio: int,
use_sliding_window: bool,
enable_pdl: bool,
) -> JitSpec:
if input_dtype == torch.float16:
flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
@@ -84,10 +85,15 @@
)
sm_nvcc_flags = nvcc_flags

if enable_pdl:
flag_enable_pdl = ["-DENABLE_PDL=2"]
else:
flag_enable_pdl = ["-DENABLE_PDL=0"]

flag_mla_wrapper = ["-DMLA_WRAPPER=0"]

return gen_jit_spec(
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}",
[
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
@@ -103,6 +109,7 @@
+ flag_kv_cache_dtype
+ flag_head_group_ratio
+ flag_sliding_window
+ flag_enable_pdl
+ flag_mla_wrapper,
extra_ldflags=["-lcuda"], # Add CUDA Driver API library
extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"],
@@ -116,6 +123,7 @@ def gen_xqa_module_mla(
head_dim: int,
head_group_ratio: int,
use_sliding_window: bool = False,
enable_pdl: bool = True,
) -> JitSpec:
assert head_group_ratio == 128, "Only head group ratio 128 is supported for xqa MLA"
assert head_dim == 576, "Only head dim 576 is supported for xqa_module_mla"
@@ -145,10 +153,15 @@
nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[12])
sm_nvcc_flags = nvcc_flags

if enable_pdl:
flag_enable_pdl = ["-DENABLE_PDL=2"]
else:
flag_enable_pdl = ["-DENABLE_PDL=0"]

flag_mla_wrapper = ["-DMLA_WRAPPER=1"]

return gen_jit_spec(
f"xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
f"xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}",
[
jit_env.FLASHINFER_CSRC_DIR / "xqa/mla_sm120.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp",
@@ -162,7 +175,8 @@
+ flag_kv_cache_dtype
+ flag_head_group_ratio
+ flag_sliding_window
+ flag_mla_wrapper,
+ flag_mla_wrapper
+ flag_enable_pdl,
extra_ldflags=["-lcuda"], # Add CUDA Driver API library
extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"],
)
16 changes: 12 additions & 4 deletions flashinfer/xqa.py
@@ -37,6 +37,7 @@ def get_xqa_module(
head_dim: int,
head_group_ratio: int,
use_sliding_window: bool,
enable_pdl: bool,
):
module = gen_xqa_module(
input_dtype,
@@ -45,10 +46,11 @@
head_dim,
head_group_ratio,
use_sliding_window,
enable_pdl,
).build_and_load()

@register_custom_op(
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}",
mutates_args=("output", "workspace_buffer"),
)
def xqa(
@@ -91,7 +93,7 @@ def xqa(
)

@register_fake_op(
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}"
)
def _fake_xqa(
run_sm90_fp8_mha: bool,
@@ -135,6 +137,7 @@ def xqa(
kv_scale: Optional[torch.Tensor] = None,
sliding_win_size: int = 0,
sm_count: Optional[int] = None,
enable_pdl: bool = True,
) -> None:
r"""Apply attention with paged KV cache using XQA kernel.
Parameters
@@ -239,6 +242,7 @@
head_dim,
head_group_ratio,
use_sliding_window,
enable_pdl,
)
xqa_module.xqa(
run_sm90_fp8_mha,
@@ -269,6 +273,7 @@ def get_xqa_module_mla(
head_dim: int,
head_group_ratio: int,
use_sliding_window: bool = False,
enable_pdl: bool = True,
):
module = gen_xqa_module_mla(
input_dtype,
@@ -277,10 +282,11 @@
head_dim,
head_group_ratio,
use_sliding_window,
enable_pdl,
).build_and_load()

@register_custom_op(
f"flashinfer::xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
f"flashinfer::xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}",
mutates_args=("output", "workspace_buffer"),
)
def xqa_mla(
@@ -315,7 +321,7 @@ def xqa_mla(
)

@register_fake_op(
f"flashinfer::xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
f"flashinfer::xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}"
)
def _fake_xqa_mla(
sm_count: int,
@@ -352,6 +358,7 @@ def xqa_mla(
q_scale: float = 1.0,
kv_scale: Optional[torch.Tensor] = None,
sm_count: Optional[int] = None,
enable_pdl: bool = True,
) -> None:
r"""Apply attention with paged KV cache using XQA MLA (Multi-Head Latent Attention) kernel.
Parameters
@@ -431,6 +438,7 @@
head_dim,
head_group_ratio,
False,
enable_pdl,
)
xqa_module.xqa_mla(
sm_count,