10 changes: 10 additions & 0 deletions docs/api/attention.rst
@@ -37,6 +37,16 @@ Batch Decoding

.. automethod:: __init__

XQA
^^^

.. currentmodule:: flashinfer.xqa

.. autosummary::
:toctree: ../generated

xqa


flashinfer.prefill
==================
270 changes: 181 additions & 89 deletions flashinfer/xqa.py
@@ -16,13 +16,16 @@

import functools
from types import SimpleNamespace
from typing import Optional

import torch

from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec, sm90a_nvcc_flags
from .jit.utils import filename_safe_dtype_map
from .utils import (
get_device_sm_count,
register_custom_op,
register_fake_op,
)
@@ -38,38 +41,42 @@


def gen_xqa_module(
use_fp16: bool,
token_per_page: int,
head_size: int,
head_grp_size: int,
dtype: torch.dtype,
page_size: int,
head_dim: int,
head_group_ratio: int,
use_sliding_window: bool,
) -> JitSpec:
if use_fp16:
flag_use_fp16 = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
if dtype == torch.float16:
flag_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
elif dtype == torch.bfloat16:
flag_dtype = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"]
else:
flag_use_fp16 = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"]
raise ValueError(
f"Invalid dtype: {dtype} for XQA, only float16 and bfloat16 are supported"
)

if token_per_page not in [16, 32, 64, 128]:
if page_size not in [16, 32, 64, 128]:
raise ValueError(
f"Invalid token_per_page: {token_per_page}, only 16, 32, 64, 128 are supported"
f"Invalid page_size: {page_size}, only 16, 32, 64, 128 are supported"
)
flag_tokens_per_page = [f"-DTOKENS_PER_PAGE={token_per_page}"]
flag_tokens_per_page = [f"-DTOKENS_PER_PAGE={page_size}"]

if head_size % 16 != 0 or head_size > 256 or head_size < 16:
if head_dim % 16 != 0 or head_dim > 256 or head_dim < 16:
raise ValueError(
f"Invalid head_size: {head_size}, must be divisible by 16 and in range [16, 256]"
f"Invalid head_dim: {head_dim}, must be divisible by 16 and in range [16, 256]"
)
flag_head_size = [f"-DHEAD_ELEMS={head_size}"]
flag_head_dim = [f"-DHEAD_ELEMS={head_dim}"]

flag_head_grp_size = [f"-DHEAD_GRP_SIZE={head_grp_size}"]
flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"]

if use_sliding_window:
flag_sliding_window = ["-DSLIDING_WINDOW=1"]
else:
flag_sliding_window = ["-DSLIDING_WINDOW=0"]

return gen_jit_spec(
f"xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}",
f"xqa_dtype_{filename_safe_dtype_map[dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
[
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu",
@@ -78,83 +85,83 @@ def gen_xqa_module(
extra_cuda_cflags=xqa_nvcc_flags
+ sm90a_nvcc_flags
+ flag_tokens_per_page
+ flag_head_size
+ flag_use_fp16
+ flag_head_grp_size
+ flag_head_dim
+ flag_dtype
+ flag_head_group_ratio
+ flag_sliding_window,
)
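
For a concrete picture of what this JIT spec encodes, here is a minimal sketch of a call based only on the signature shown in this diff; the chosen dtype, page size, head_dim and head-group ratio are illustrative, not values required by the kernel. The snippet is not part of the PR.

import torch
from flashinfer.xqa import gen_xqa_module

# Describe a bf16 kernel with 64-token pages, head_dim=128 and 8 query heads
# per KV head, without sliding-window support. Per the code above, the spec's
# extra CUDA flags include -DTOKENS_PER_PAGE=64, -DHEAD_ELEMS=128,
# -DINPUT_FP16=0, -DDTYPE=__nv_bfloat16, -DHEAD_GRP_SIZE=8, -DSLIDING_WINDOW=0.
spec = gen_xqa_module(
    dtype=torch.bfloat16,
    page_size=64,
    head_dim=128,
    head_group_ratio=8,
    use_sliding_window=False,
)

# Unsupported configurations fail fast, e.g. page_size must be one of 16/32/64/128:
# gen_xqa_module(torch.bfloat16, 48, 128, 8, False)  # raises ValueError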


@functools.cache
def get_xqa_module(
use_fp16: bool,
token_per_page: int,
head_size: int,
head_grp_size: int,
dtype: torch.dtype,
page_size: int,
head_dim: int,
head_group_ratio: int,
use_sliding_window: bool,
):
module = gen_xqa_module(
use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window
dtype, page_size, head_dim, head_group_ratio, use_sliding_window
).build_and_load()

@register_custom_op(
f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}",
mutates_args=("output", "scratch"),
f"flashinfer::xqa_dtype_{filename_safe_dtype_map[dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
mutates_args=("output", "workspace_buffer"),
)
def xqa(
multiProcessorCount: int,
nbKHeads: int,
sm_count: int,
num_kv_heads: int,
slidingWinSize: int,
qScale: float,
q_scale: float,
output: torch.Tensor,
q: torch.Tensor,
attentionSinks: torch.Tensor,
pool: torch.Tensor,
kvCachePageList: torch.Tensor,
maxSeqLen: int,
seqLen: torch.Tensor,
batchSize: int,
kvCacheScale: torch.Tensor,
sinks: torch.Tensor,
Contributor (medium):

The type hint for the sinks parameter should be Optional[torch.Tensor] to match its usage. The public xqa function allows sinks to be None, and the C++ backend is designed to handle this case. The type hint should be updated to reflect that the parameter is optional, to avoid confusion and potential issues with static analysis tools.

Suggested change
sinks: torch.Tensor,
sinks: Optional[torch.Tensor],

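To make the reviewer's point concrete, here is a minimal illustration with hypothetical stubs (not part of this PR) of how the stricter hint interacts with a static type checker such as mypy:

from typing import Optional
import torch

def takes_required_sinks(sinks: torch.Tensor) -> None: ...
def takes_optional_sinks(sinks: Optional[torch.Tensor]) -> None: ...

takes_required_sinks(None)   # flagged by mypy: None is not a torch.Tensor
takes_optional_sinks(None)   # accepted, matching how the public xqa forwards sinks=None
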
kv_cache: torch.Tensor,
page_table: torch.Tensor,
max_seq_len: int,
seq_lens: torch.Tensor,
batch_size: int,
kv_scale: torch.Tensor,
semaphores: torch.Tensor,
scratch: torch.Tensor,
workspace_buffer: torch.Tensor,
) -> None:
module.xqa_wrapper.default(
multiProcessorCount,
nbKHeads,
sm_count,
num_kv_heads,
slidingWinSize,
qScale,
q_scale,
output,
q,
attentionSinks,
pool,
kvCachePageList,
maxSeqLen,
seqLen,
batchSize,
kvCacheScale,
sinks,
kv_cache,
page_table,
max_seq_len,
seq_lens,
batch_size,
kv_scale,
semaphores,
scratch,
workspace_buffer,
)

@register_fake_op(
f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}"
f"flashinfer::xqa_dtype_{filename_safe_dtype_map[dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
)
def _fake_xqa(
multiProcessorCount: int,
nbKHeads: int,
sm_count: int,
num_kv_heads: int,
slidingWinSize: int,
qScale: float,
q_scale: float,
output: torch.Tensor,
q: torch.Tensor,
attentionSinks: torch.Tensor,
pool: torch.Tensor,
kvCachePageList: torch.Tensor,
maxSeqLen: int,
seqLen: torch.Tensor,
batchSize: int,
kvCacheScale: torch.Tensor,
sinks: torch.Tensor,
Contributor (medium):

Similar to the xqa stub, the type hint for the sinks parameter in _fake_xqa should be Optional[torch.Tensor] to accurately reflect that it can be None.

Suggested change
sinks: torch.Tensor,
sinks: Optional[torch.Tensor],

kv_cache: torch.Tensor,
page_table: torch.Tensor,
max_seq_len: int,
seq_lens: torch.Tensor,
batch_size: int,
kv_scale: torch.Tensor,
semaphores: torch.Tensor,
scratch: torch.Tensor,
workspace_buffer: torch.Tensor,
) -> None:
pass

@@ -164,44 +171,129 @@ def _fake_xqa(


def xqa(
use_fp16: bool,
token_per_page: int,
head_size: int,
head_grp_size: int,
use_sliding_window: bool,
sliding_win_size: int,
multiProcessorCount: int,
nbKHeads: int,
qScale: float,
output: torch.Tensor,
q: torch.Tensor,
attentionSinks: torch.Tensor,
pool: torch.Tensor,
kvCachePageList: torch.Tensor,
maxSeqLen: int,
seqLen: torch.Tensor,
batchSize: int,
kvCacheScale: torch.Tensor,
kv_cache: torch.Tensor,
page_table: torch.Tensor,
seq_lens: torch.Tensor,
output: torch.Tensor,
workspace_buffer: torch.Tensor,
semaphores: torch.Tensor,
scratch: torch.Tensor,
num_kv_heads: int,
page_size: int,
sinks: Optional[torch.Tensor] = None,
q_scale: float = 1.0,
kv_scale: Optional[torch.Tensor] = None,
sliding_win_size: int = 0,
sm_count: Optional[int] = None,
) -> None:
r"""Apply attention with paged KV cache using XQA kernel.

Parameters
----------
q : torch.Tensor
Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
Data type should be torch.float16 or torch.bfloat16.

kv_cache : torch.Tensor
Paged KV cache tensor with shape ``[total_num_cache_heads, head_dim]``.
Contains both K and V cache data interleaved.
Data type should match query tensor.

page_table : torch.Tensor
Page table tensor with shape ``[batch_size, beam_width, 2, num_pages_per_seq]``.
Data type should be torch.uint32.
The third dimension represents K and V cache (0 for K, 1 for V).

seq_lens : torch.Tensor
Sequence lengths tensor with shape ``[batch_size, beam_width]``.
Data type should be torch.uint32.

output : torch.Tensor
Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
Data type should match query tensor. This tensor will be modified in-place.

workspace_buffer : torch.Tensor
Workspace buffer for temporary computations.
Data type should be torch.uint8.

semaphores : torch.Tensor
Semaphore buffer for synchronization.
Data type should be torch.uint32.

num_kv_heads : int
Number of key-value heads in the attention mechanism.

page_size : int
Size of each page in the paged KV cache. Must be one of [16, 32, 64, 128].

sinks : Optional[torch.Tensor], default=None
Attention sink values with shape ``[num_kv_heads, head_group_ratio]``.
Data type should be torch.float32.
If None, no attention sinks are used.

q_scale : float, default=1.0
Scale factor for query tensor.

kv_scale : Optional[torch.Tensor], default=None
Scale factor for KV cache with shape ``[1]``.
Data type should be torch.float32.
If None, defaults to 1.0.

sliding_win_size : int, default=0
Sliding window size for attention. If 0, no sliding window is used.

sm_count : Optional[int], default=None
Number of streaming multiprocessors to use.
If None, will be inferred from the device.

Note
----
The function automatically infers several parameters from tensor shapes:
- batch_size from q.shape[0]
- num_q_heads from q.shape[2]
- head_dim from q.shape[-1]
- use_fp16 from q.dtype
- head_group_ratio from num_q_heads // num_kv_heads
"""
# Handle optional parameters
if sm_count is None:
sm_count = get_device_sm_count(q.device)

if kv_scale is None:
kv_scale = torch.ones(1, dtype=torch.float32, device=q.device)

# Infer parameters from tensors
batch_size = q.shape[0]
num_q_heads = q.shape[2]
head_dim = q.shape[-1]

# Calculate head_group_ratio
head_group_ratio = num_q_heads // num_kv_heads

# Calculate max_seq_len from page_table and page_size
num_pages_per_seq = page_table.shape[-1]
max_seq_len = num_pages_per_seq * page_size

# Determine if sliding window is used
use_sliding_window = sliding_win_size >= 0
Contributor (critical):

There's a logic error in determining use_sliding_window. According to the docstring, a sliding_win_size of 0 should disable the sliding window; however, with the current check sliding_win_size >= 0, it will always be enabled. This causes the kernel to be compiled with SLIDING_WINDOW=1 and, when called with slidingWinSize=0, it incorrectly skips all tokens in the sequence. The condition should be sliding_win_size > 0 to correctly disable the sliding window when the size is 0.

Suggested change
use_sliding_window = sliding_win_size >= 0
use_sliding_window = sliding_win_size > 0

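A quick plain-Python check of the two predicates (illustrative, not part of the PR) makes the difference concrete; only the > 0 form treats a window size of 0 as "disabled", as the docstring promises:

for sliding_win_size in (0, 1, 128):
    buggy   = sliding_win_size >= 0  # always True: even 0 compiles with SLIDING_WINDOW=1
    correct = sliding_win_size > 0   # False for 0: sliding window genuinely disabled
    print(sliding_win_size, buggy, correct)
# prints: 0 True False / 1 True True / 128 True True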

xqa_module = get_xqa_module(
use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window
q.dtype, page_size, head_dim, head_group_ratio, use_sliding_window
)
xqa_module.xqa(
multiProcessorCount,
nbKHeads,
sm_count,
num_kv_heads,
sliding_win_size if use_sliding_window else 0,
qScale,
q_scale,
output,
q,
attentionSinks,
pool,
kvCachePageList,
maxSeqLen,
seqLen,
batchSize,
kvCacheScale,
sinks,
kv_cache,
page_table,
max_seq_len,
seq_lens,
batch_size,
kv_scale,
semaphores,
scratch,
workspace_buffer,
)
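
For orientation, a usage sketch of the new calling convention follows. It assumes an SM90 GPU and a PyTorch build that can allocate torch.uint32 tensors; the batch/head/page dimensions and the kv_cache, workspace and semaphore buffer sizes are illustrative placeholders, not values prescribed by the kernel, and the snippet is not part of the PR.

import torch
from flashinfer.xqa import xqa

batch_size, beam_width = 2, 1
num_kv_heads, head_group_ratio, head_dim = 4, 8, 128
num_q_heads = num_kv_heads * head_group_ratio          # 32
page_size, num_pages_per_seq = 64, 8                   # max_seq_len is inferred as 8 * 64 = 512

dev, dtype = "cuda", torch.bfloat16
q = torch.randn(batch_size, beam_width, num_q_heads, head_dim, dtype=dtype, device=dev)
output = torch.empty_like(q)

# Paged K/V cache, [total_num_cache_heads, head_dim]; the total size used here is a placeholder.
total_cache_heads = batch_size * num_pages_per_seq * 2 * num_kv_heads * page_size
kv_cache = torch.randn(total_cache_heads, head_dim, dtype=dtype, device=dev)

# Page table ([batch, beam, 2, pages_per_seq]) and per-sequence lengths, both uint32.
page_table = torch.zeros(batch_size, beam_width, 2, num_pages_per_seq, dtype=torch.uint32, device=dev)
seq_lens = torch.full((batch_size, beam_width), 500, dtype=torch.uint32, device=dev)

# Scratch buffers; size them according to the kernel's actual requirements.
workspace_buffer = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=dev)
semaphores = torch.zeros(8192, dtype=torch.uint32, device=dev)

xqa(
    q, kv_cache, page_table, seq_lens, output,
    workspace_buffer, semaphores,
    num_kv_heads=num_kv_heads, page_size=page_size,
)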