Merged
8 changes: 3 additions & 5 deletions csrc/flashinfer_xqa_binding.cu
@@ -34,11 +34,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
TensorView seqLen, int64_t batchSize, double kvCacheScale,
tvm::ffi::Optional<TensorView> kvScaleTensor,
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch, bool enable_pdl);
tvm::ffi::Optional<TensorView> kvScaleTensor, int64_t qSeqLen,
tvm::ffi::Optional<TensorView> mask, TensorView semaphores, TensorView scratch,
bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);

27 changes: 27 additions & 0 deletions csrc/xqa/mha.cu
@@ -1267,6 +1267,23 @@ __device__ inline void addAttentionSinks(ThrdRegRowMax& globalRowSum,
}
}

#if SPEC_DEC
// SPEC_DEC version: handles head-token mixed layout
__device__ inline void addAttentionSinksSpecDec(ThrdRegRowMax& globalRowSum,
ThrdRegRowMax const globalRowMax,
float const* attentionSinks, uint32_t headGrpSize) {
for (uint32_t i = 0; i < globalRowSum.size; i++) {
uint32_t idxHeadToken = warp_size * i + laneId();
// In SPEC_DEC, layout is [token0_head0, token0_head1, ..., token1_head0, ...]
// Extract head index from head-token index
uint32_t headIdx = idxHeadToken % headGrpSize;
if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock) {
Reviewer comment (Contributor, severity: medium):
The check headIdx < headGrpSize is redundant. headIdx is calculated as idxHeadToken % headGrpSize, and the result of the modulo operation with a positive headGrpSize is always in the range [0, headGrpSize - 1]. Removing this redundant check simplifies the code.

Suggested change:
    if (idxHeadToken < rowsPerBlock) {

globalRowSum[i] += expf(attentionSinks[headIdx] - globalRowMax[i]);
}
}
}
#endif

#ifdef NDEBUG
__device__ __forceinline__
#else
@@ -2169,7 +2186,12 @@ CUBIN_EXPORT __global__
// enabled.
if (!isMultiBlock && attentionSinks != nullptr) {
// Attention sinks are per head.
#if SPEC_DEC
addAttentionSinksSpecDec(globalRowSum, globalRowMax,
attentionSinks + headGrpSize * idxHeadGrp, headGrpSize);
#else
addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp);
#endif
}
ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum);
#if LOW_PREC_OUTPUT
@@ -2349,7 +2371,12 @@ CUBIN_EXPORT __global__
}
if (attentionSinks != nullptr) {
// Attention sinks are per head.
#if SPEC_DEC
addAttentionSinksSpecDec(mergedRowSum, mergedRowMax,
attentionSinks + headGrpSize * idxHeadGrp, headGrpSize);
#else
addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp);
#endif
}
__syncthreads();
rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum));
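For reference, the per-row update performed by the new addAttentionSinksSpecDec path can be sketched outside the kernel. The snippet below is a minimal NumPy illustration of the math only, with hypothetical names and shapes; it ignores the warp tiling (warp_size, laneId(), rowsPerBlock) that the CUDA code uses, and the pre-existing addAttentionSinks still covers the non-speculative layout.

    import numpy as np

    def add_attention_sinks_spec_dec(row_sum, row_max, sinks, head_grp_size):
        # Rows are head-token indices laid out as
        # [tok0_head0, tok0_head1, ..., tok0_head{G-1}, tok1_head0, ...],
        # so the head owning each row is row % head_grp_size.
        rows = np.arange(row_sum.shape[0])
        head_idx = rows % head_grp_size
        # Each row's softmax denominator gains one extra exp(sink - rowMax) term.
        return row_sum + np.exp(sinks[head_idx] - row_max)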
18 changes: 12 additions & 6 deletions csrc/xqa/xqa_wrapper.cu
@@ -56,10 +56,8 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
Optional<TensorView> attentionSinks, TensorView kCacheVLLM, TensorView vCacheVLLM,
TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
int64_t batchSize, double kvCacheScale, Optional<TensorView> kvScaleTensor,
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch, bool enable_pdl) {
int64_t qSeqLen, Optional<TensorView> mask, TensorView semaphores,
TensorView scratch, bool enable_pdl) {
auto stream = get_stream(output.device());
float const* attentionSinksPtr =
attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks.value().data_ptr())
@@ -70,13 +68,22 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
float const* kvScalePtr = kvScaleTensor.has_value()
? reinterpret_cast<float const*>(kvScaleTensor.value().data_ptr())
: nullptr;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900
auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
#else
auto const mha_func = &launchMHAFlashInfer;
#endif

// Extract strides from TensorView (in elements, not bytes)
uint64_t kv_stride_page = kCacheVLLM.stride(0);
uint64_t kv_stride_token = kCacheVLLM.stride(-3);
uint64_t kv_stride_head = kCacheVLLM.stride(-2);

#if SPEC_DEC
MaskType const* maskPtr =
mask.has_value() ? reinterpret_cast<MaskType const*>(mask.value().data_ptr()) : nullptr;
#endif

mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, qScalePtr,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
@@ -89,8 +96,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, kvCacheScale,
kvScalePtr,
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
qSeqLen, nullptr, maskPtr,
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, kv_stride_page, kv_stride_token,
16 changes: 14 additions & 2 deletions flashinfer/decode.py
@@ -2079,6 +2079,7 @@ def trtllm_batch_decode_with_kv_cache(
backend: str = "auto",
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
@@ -2149,6 +2150,9 @@ def trtllm_batch_decode_with_kv_cache(
o_scale : Optional[float] = 1.0
output scale factor for xqa fp8 output.

mask : Optional[torch.Tensor] = None
causal attention mask for xqa speculative decoding.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
@@ -2204,6 +2208,7 @@ def trtllm_batch_decode_with_kv_cache(
enable_pdl=enable_pdl,
q_len_per_req=q_len_per_req,
o_scale=o_scale,
mask=mask,
)
elif backend == "trtllm-gen":
# Convert NHD layout to HND if necessary (transpose only changes stride, not data)
@@ -2348,6 +2353,7 @@ def xqa_batch_decode_with_kv_cache(
enable_pdl: bool = None,
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Parameters
@@ -2399,15 +2405,16 @@ def xqa_batch_decode_with_kv_cache(
o_scale : Optional[float] = 1.0
output scale factor for fp8 output.

mask : Optional[torch.Tensor] = None
causal attention mask for xqa speculative decoding.

Returns
-------
out : torch.Tensor
output torch.Tensor.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

assert q_len_per_req == 1, "xqa not support speculative decoding yet"

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
else:
@@ -2441,6 +2448,9 @@ def xqa_batch_decode_with_kv_cache(
kv_scale_value = bmm2_scale * o_scale
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)

if q_len_per_req > 1:
batch_size = query.shape[0] // q_len_per_req
query = query.view(batch_size, q_len_per_req, query.shape[1], query.shape[2])
query_new = query.unsqueeze(1)
seq_lens_new = seq_lens.unsqueeze(1)
sinks_new = sinks.reshape(num_kv_heads, -1) if sinks is not None else None
@@ -2469,6 +2479,8 @@ def xqa_batch_decode_with_kv_cache(
sm_count=sm_count,
enable_pdl=enable_pdl,
rcp_out_scale=1.0 / o_scale,
q_seq_len=q_len_per_req,
mask=mask,
)

return out
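A short shape walk-through of the new q_len_per_req > 1 branch in xqa_batch_decode_with_kv_cache; the sizes below are made up for illustration, and the reshapes mirror what the wrapper now does before calling xqa().

    import torch

    batch_size, q_len_per_req, num_q_heads, head_dim = 4, 3, 32, 128
    # Callers pass the draft tokens flattened into the batch dimension.
    query = torch.randn(batch_size * q_len_per_req, num_q_heads, head_dim,
                        dtype=torch.bfloat16, device="cuda")

    # Recover [batch, q_len, heads, dim], then insert beam_width = 1.
    q = query.view(batch_size, q_len_per_req, num_q_heads, head_dim)
    q = q.unsqueeze(1)
    assert q.shape == (batch_size, 1, q_len_per_req, num_q_heads, head_dim)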
36 changes: 27 additions & 9 deletions flashinfer/jit/xqa.py
@@ -28,7 +28,6 @@
"-DBEAM_WIDTH=1",
"-DUSE_INPUT_KV=0",
"-DUSE_CUSTOM_BARRIER=1",
"-DSPEC_DEC=0",
]


@@ -40,6 +39,7 @@ def gen_xqa_module(
head_group_ratio: int,
use_sliding_window: bool,
output_dtype: torch.dtype,
q_seq_len: int = 1,
) -> JitSpec:
if input_dtype == torch.float16:
flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
@@ -81,6 +81,16 @@
else:
flag_low_prec_output = ["-DLOW_PREC_OUTPUT=0"]

if q_seq_len > 1:
use_spec_dec = True
if q_seq_len * head_group_ratio <= 32:
flag_spec_dec = ["-DSPEC_DEC=1", f"-DSPEC_Q_SEQ_LEN={q_seq_len}"]
else:
flag_spec_dec = ["-DSPEC_DEC=1"]
else:
flag_spec_dec = ["-DSPEC_DEC=0"]
use_spec_dec = False

compilation_context = CompilationContext()
nvcc_flags = compilation_context.get_nvcc_flags_list(
supported_major_versions=[9, 10, 11, 12]
@@ -89,15 +99,22 @@

flag_mla_wrapper = ["-DMLA_WRAPPER=0"]

sources = [
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu",
]

target_archs = compilation_context.TARGET_CUDA_ARCHS

has_sm90 = any(major == 9 for major, minor in target_archs)
if has_sm90:
sources.append(jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu")
sources.append(jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp")

return gen_jit_spec(
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
[
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp",
jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu",
],
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}",
sources,
extra_cuda_cflags=xqa_nvcc_flags
+ sm_nvcc_flags
+ flag_tokens_per_page
@@ -107,6 +124,7 @@ def gen_xqa_module(
+ flag_head_group_ratio
+ flag_sliding_window
+ flag_low_prec_output
+ flag_spec_dec
+ flag_mla_wrapper,
extra_ldflags=["-lcuda"], # Add CUDA Driver API library
)
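The SPEC_DEC flag selection above reduces to a small decision table. The helper below simply restates that branch for readability; the observation that the 32-row threshold matches one warp's worth of head-token rows is an inference, not something stated in the diff.

    def spec_dec_flags(q_seq_len: int, head_group_ratio: int) -> list:
        # Mirrors the branch in gen_xqa_module.
        if q_seq_len > 1:
            if q_seq_len * head_group_ratio <= 32:
                return ["-DSPEC_DEC=1", f"-DSPEC_Q_SEQ_LEN={q_seq_len}"]
            return ["-DSPEC_DEC=1"]  # q_seq_len stays a runtime value
        return ["-DSPEC_DEC=0"]

    assert spec_dec_flags(4, 8) == ["-DSPEC_DEC=1", "-DSPEC_Q_SEQ_LEN=4"]
    assert spec_dec_flags(8, 8) == ["-DSPEC_DEC=1"]
    assert spec_dec_flags(1, 8) == ["-DSPEC_DEC=0"]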
45 changes: 39 additions & 6 deletions flashinfer/xqa.py
@@ -39,6 +39,7 @@ def get_xqa_module(
head_group_ratio: int,
use_sliding_window: bool,
output_dtype: torch.dtype,
q_seq_len: int,
):
module = gen_xqa_module(
input_dtype,
@@ -48,10 +49,16 @@ def get_xqa_module(
head_group_ratio,
use_sliding_window,
output_dtype,
q_seq_len,
).build_and_load()

if q_seq_len > 1:
use_spec_dec = True
else:
use_spec_dec = False

@register_custom_op(
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}",
mutates_args=("output", "workspace_buffer"),
)
def xqa(
@@ -74,6 +81,8 @@ def xqa(
semaphores: torch.Tensor,
workspace_buffer: torch.Tensor,
enable_pdl: bool,
q_seq_len: int,
mask: Optional[torch.Tensor],
) -> None:
module.xqa_wrapper(
run_sm90_fp8_mha,
@@ -94,13 +103,15 @@ def xqa(
batch_size,
1.0 if isinstance(kv_scale, torch.Tensor) else kv_scale,
None if isinstance(kv_scale, float) else kv_scale,
q_seq_len,
mask,
semaphores,
workspace_buffer,
enable_pdl,
)

@register_fake_op(
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}"
)
def _fake_xqa(
run_sm90_fp8_mha: bool,
@@ -122,6 +133,8 @@ def _fake_xqa(
semaphores: torch.Tensor,
workspace_buffer: torch.Tensor,
enable_pdl: bool,
q_seq_len: int,
mask: Optional[torch.Tensor],
) -> None:
pass

@@ -149,12 +162,15 @@ def xqa(
sm_count: Optional[int] = None,
enable_pdl: Optional[bool] = None,
rcp_out_scale: float = 1.0,
q_seq_len: int = 1,
mask: Optional[torch.Tensor] = None,
) -> None:
r"""Apply attention with paged KV cache using XQA kernel.
Parameters
----------
q : torch.Tensor
Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]`` if not using speculative decoding,
or ``[batch_size, beam_width, q_seq_len, num_q_heads, head_dim]`` if using speculative decoding. ``q_seq_len`` is the number of speculative decoding tokens.
Data type should be torch.float16 or torch.bfloat16.
Now only beam_width 1 is supported.
k_cache: torch.Tensor
@@ -175,7 +191,7 @@
Sequence lengths tensor with shape ``[batch_size, beam_width]``.
Data type should be torch.uint32.
output : torch.Tensor
Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
Output tensor with shape that matches the query tensor.
Data type should match query tensor or kv tensor. This tensor will be modified in-place.
workspace_buffer : torch.Tensor
Workspace buffer for temporary computations.
@@ -207,12 +223,19 @@
If None, will be set to True if hardware supports it.
rcp_out_scale : float, default=1.0
Reciprocal of output scale factor.
q_seq_len : int, default=1
Query sequence length. When > 1, enables speculative decoding mode.
mask : Optional[torch.Tensor], default=None
Causal attention mask for speculative decoding mode (when ``q_seq_len > 1``).
Shape: ``[batch_size, q_seq_len, mask_size_per_row]`` where
``mask_size_per_row = ((q_seq_len + 31) // 32) * 2``.
Data type should be torch.uint16 (bit-packed format, aligned to 32 bits).

Note
----
The function automatically infers several parameters from tensor shapes:
- batch_size from q.shape[0]
- num_q_heads from q.shape[2]
- num_q_heads from q.shape[-2]
- head_dim from q.shape[-1]
- input_dtype from q.dtype
- kv_cache_dtype from k.dtype
@@ -227,7 +250,7 @@

# Infer parameters from tensors
batch_size = q.shape[0]
num_q_heads = q.shape[2]
num_q_heads = q.shape[-2]
head_dim = q.shape[-1]

# Calculate head_group_ratio
@@ -274,7 +297,15 @@
head_group_ratio,
use_sliding_window,
output.dtype,
q_seq_len,
)

if q_seq_len > 1:
assert mask is not None, "Mask is required for speculative decoding"
run_sm90_fp8_mha = (
False # TODO: mha_sm90.cu has precision issue with speculative decoding
)

xqa_module.xqa(
run_sm90_fp8_mha,
sm_count,
@@ -295,6 +326,8 @@
semaphores,
workspace_buffer,
enable_pdl,
q_seq_len,
mask,
)


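Given the mask layout documented above ([batch_size, q_seq_len, ((q_seq_len + 31) // 32) * 2], bit-packed torch.uint16), a causal mask over the draft tokens could be built roughly as below. The helper name is hypothetical, the least-significant-bit-first packing within each 16-bit word is an assumption that should be checked against the kernel, and the NumPy round trip assumes a PyTorch build that supports torch.uint16.

    import numpy as np
    import torch

    def build_causal_spec_dec_mask(batch_size, q_seq_len):
        # One bit per (query token i, key token j) pair among the draft tokens,
        # packed into uint16 words and padded to a 32-bit boundary per row.
        words_per_row = ((q_seq_len + 31) // 32) * 2
        row = np.zeros((q_seq_len, words_per_row), dtype=np.uint16)
        for i in range(q_seq_len):
            for j in range(i + 1):  # causal: token i attends to tokens 0..i
                row[i, j // 16] |= np.uint16(1 << (j % 16))
        mask = np.broadcast_to(row, (batch_size, q_seq_len, words_per_row)).copy()
        return torch.from_numpy(mask)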