8 changes: 3 additions & 5 deletions csrc/flashinfer_xqa_binding.cu
@@ -30,11 +30,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale,
TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
TensorView seqLen, int64_t batchSize, double kvCacheScale,
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch, bool enable_pdl);
TensorView seqLen, int64_t batchSize, double kvCacheScale, int64_t qSeqLen,
tvm::ffi::Optional<TensorView> mask, TensorView semaphores, TensorView scratch,
bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);

27 changes: 27 additions & 0 deletions csrc/xqa/mha.cu
@@ -1262,6 +1262,23 @@ __device__ inline void addAttentionSinks(ThrdRegRowMax& globalRowSum,
}
}

#if SPEC_DEC
// SPEC_DEC version: handles head-token mixed layout
__device__ inline void addAttentionSinksSpecDec(ThrdRegRowMax& globalRowSum,
ThrdRegRowMax const globalRowMax,
float const* attentionSinks, uint32_t headGrpSize) {
for (uint32_t i = 0; i < globalRowSum.size; i++) {
uint32_t idxHeadToken = warp_size * i + laneId();
// In SPEC_DEC, layout is [token0_head0, token0_head1, ..., token1_head0, ...]
// Extract head index from head-token index
uint32_t headIdx = idxHeadToken % headGrpSize;
if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock) {
[Review comment — Contributor, severity: medium]

The check headIdx < headGrpSize is redundant. headIdx is calculated as idxHeadToken % headGrpSize. The result of the modulo operator % with a positive headGrpSize will always be in the range [0, headGrpSize - 1]. Removing this redundant check simplifies the code.

Suggested change:

    if (idxHeadToken < rowsPerBlock) {

globalRowSum[i] += expf(attentionSinks[headIdx] - globalRowMax[i]);
}
}
}
#endif

#ifdef NDEBUG
__device__ __forceinline__
#else
@@ -2161,7 +2178,12 @@ CUBIN_EXPORT __global__
// enabled.
if (!isMultiBlock && attentionSinks != nullptr) {
// Attention sinks are per head.
#if SPEC_DEC
addAttentionSinksSpecDec(globalRowSum, globalRowMax,
attentionSinks + headGrpSize * idxHeadGrp, headGrpSize);
#else
addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp);
#endif
}
ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum);
#if LOW_PREC_OUTPUT
@@ -2341,7 +2363,12 @@ CUBIN_EXPORT __global__
}
if (attentionSinks != nullptr) {
// Attention sinks are per head.
#if SPEC_DEC
addAttentionSinksSpecDec(mergedRowSum, mergedRowMax,
attentionSinks + headGrpSize * idxHeadGrp, headGrpSize);
#else
addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp);
#endif
}
__syncthreads();
rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum));
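A minimal Python sketch of the index math and sink correction that the new addAttentionSinksSpecDec above relies on; it is illustrative only and not part of the diff. It assumes the SPEC_DEC row layout stated in the code comments ([token0_head0, token0_head1, ..., token1_head0, ...]) and mirrors the per-row update globalRowSum[i] += expf(attentionSinks[headIdx] - globalRowMax[i]); the warp/register distribution of rows in the real kernel is ignored here.

import numpy as np

def add_attention_sinks_spec_dec(global_row_sum, global_row_max,
                                 attention_sinks, head_grp_size):
    # global_row_sum / global_row_max: 1-D arrays indexed by head-token row.
    # attention_sinks: one sink logit per head in the group.
    for idx_head_token in range(global_row_sum.shape[0]):
        head_idx = idx_head_token % head_grp_size      # head within the group
        # token_idx = idx_head_token // head_grp_size  # not needed for sinks
        global_row_sum[idx_head_token] += np.exp(
            attention_sinks[head_idx] - global_row_max[idx_head_token]
        )
    return global_row_sum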
20 changes: 13 additions & 7 deletions csrc/xqa/xqa_wrapper.cu
@@ -48,22 +48,29 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
int64_t slidingWinSize, double qScale, TensorView output, double rcpOutScale,
TensorView q, Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
TensorView seqLen, int64_t batchSize, double kvCacheScale,
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch, bool enable_pdl) {
TensorView seqLen, int64_t batchSize, double kvCacheScale, int64_t qSeqLen,
Optional<TensorView> mask, TensorView semaphores, TensorView scratch,
bool enable_pdl) {
auto stream = get_stream(output.device());
float const* attentionSinksPtr =
attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks.value().data_ptr())
: nullptr;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900
auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
#else
auto const mha_func = &launchMHAFlashInfer;
#endif

// Extract strides from TensorView (in elements, not bytes)
uint64_t kv_stride_page = kCacheVLLM.stride(0);
uint64_t kv_stride_token = kCacheVLLM.stride(-3);
uint64_t kv_stride_head = kCacheVLLM.stride(-2);

#if SPEC_DEC
MaskType const* maskPtr =
mask.has_value() ? reinterpret_cast<MaskType const*>(mask.value().data_ptr()) : nullptr;
#endif

mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
@@ -75,8 +82,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, kvCacheScale,
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
qSeqLen, nullptr, maskPtr,
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, kv_stride_page, kv_stride_token,
16 changes: 14 additions & 2 deletions flashinfer/decode.py
@@ -2079,6 +2079,7 @@ def trtllm_batch_decode_with_kv_cache(
backend: str = "auto",
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
@@ -2149,6 +2150,9 @@ def trtllm_batch_decode_with_kv_cache(
o_scale : Optional[float] = 1.0
output scale factor for xqa fp8 output.

mask : Optional[torch.Tensor] = None
causal attention mask for xqa speculative decoding.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
@@ -2209,6 +2213,7 @@ def trtllm_batch_decode_with_kv_cache(
enable_pdl=enable_pdl,
q_len_per_req=q_len_per_req,
o_scale=o_scale,
mask=mask,
)
elif backend == "trtllm-gen":
# Convert NHD layout to HND if necessary (transpose only changes stride, not data)
@@ -2353,6 +2358,7 @@ def xqa_batch_decode_with_kv_cache(
enable_pdl: bool = None,
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Parameters
@@ -2404,15 +2410,16 @@ def xqa_batch_decode_with_kv_cache(
o_scale : Optional[float] = 1.0
output scale factor for fp8 output.

mask : Optional[torch.Tensor] = None
causal attention mask for xqa speculative decoding.

Returns
-------
out : torch.Tensor
output torch.Tensor.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

assert q_len_per_req == 1, "xqa not support speculative decoding yet"

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
else:
@@ -2453,6 +2460,9 @@ def xqa_batch_decode_with_kv_cache(
kv_scale_value = bmm2_scale * o_scale
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)

if q_len_per_req > 1:
batch_size = query.shape[0] // q_len_per_req
query = query.view(batch_size, q_len_per_req, query.shape[1], query.shape[2])
query_new = query.unsqueeze(1)
seq_lens_new = seq_lens.unsqueeze(1)
sinks_new = sinks.reshape(num_kv_heads, -1) if sinks is not None else None
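As a side note on the spec-dec query reshape added above, a minimal sketch (not part of the diff) of how a flat query of shape [batch_size * q_len_per_req, num_q_heads, head_dim] ends up in the 5-D layout the xqa kernel receives; all sizes below are hypothetical.

import torch

batch_size, q_len_per_req, num_q_heads, head_dim = 2, 4, 8, 128
query = torch.randn(batch_size * q_len_per_req, num_q_heads, head_dim, dtype=torch.bfloat16)

if q_len_per_req > 1:
    # Same view as in xqa_batch_decode_with_kv_cache: [batch * q_len, H, D] -> [batch, q_len, H, D]
    query = query.view(batch_size, q_len_per_req, query.shape[1], query.shape[2])

# unsqueeze(1) then inserts the beam_width dimension (beam_width == 1):
# [batch, 1, q_len, H, D] for spec-dec, [batch, 1, H, D] otherwise.
query_new = query.unsqueeze(1)
print(query_new.shape)  # torch.Size([2, 1, 4, 8, 128])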
@@ -2481,6 +2491,8 @@ def xqa_batch_decode_with_kv_cache(
sm_count=sm_count,
enable_pdl=enable_pdl,
rcp_out_scale=1.0 / o_scale,
q_seq_len=q_len_per_req,
mask=mask,
)

return out
36 changes: 27 additions & 9 deletions flashinfer/jit/xqa.py
@@ -28,7 +28,6 @@
"-DBEAM_WIDTH=1",
"-DUSE_INPUT_KV=0",
"-DUSE_CUSTOM_BARRIER=1",
"-DSPEC_DEC=0",
]


@@ -40,6 +39,7 @@ def gen_xqa_module(
head_group_ratio: int,
use_sliding_window: bool,
output_dtype: torch.dtype,
q_seq_len: int = 1,
) -> JitSpec:
if input_dtype == torch.float16:
flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
@@ -81,6 +81,16 @@ def gen_xqa_module(
else:
flag_low_prec_output = ["-DLOW_PREC_OUTPUT=0"]

if q_seq_len > 1:
use_spec_dec = True
if q_seq_len * head_group_ratio <= 32:
flag_spec_dec = ["-DSPEC_DEC=1", f"-DSPEC_Q_SEQ_LEN={q_seq_len}"]
else:
flag_spec_dec = ["-DSPEC_DEC=1"]
else:
flag_spec_dec = ["-DSPEC_DEC=0"]
use_spec_dec = False

compilation_context = CompilationContext()
nvcc_flags = compilation_context.get_nvcc_flags_list(
supported_major_versions=[9, 10, 11, 12]
@@ -89,15 +99,22 @@ def gen_xqa_module(

flag_mla_wrapper = ["-DMLA_WRAPPER=0"]

sources = [
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu",
]

target_archs = compilation_context.TARGET_CUDA_ARCHS

has_sm90 = any(major == 9 for major, minor in target_archs)
if has_sm90:
sources.append(jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu")
sources.append(jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp")

return gen_jit_spec(
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
[
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp",
jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_binding.cu",
],
f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}",
sources,
extra_cuda_cflags=xqa_nvcc_flags
+ sm_nvcc_flags
+ flag_tokens_per_page
Expand All @@ -107,6 +124,7 @@ def gen_xqa_module(
+ flag_head_group_ratio
+ flag_sliding_window
+ flag_low_prec_output
+ flag_spec_dec
+ flag_mla_wrapper,
extra_ldflags=["-lcuda"], # Add CUDA Driver API library
)
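The SPEC_DEC compile-flag selection added above, reproduced as a standalone sketch (illustrative only; the helper name is hypothetical). It mirrors the q_seq_len * head_group_ratio <= 32 check in gen_xqa_module, which additionally passes the exact q_seq_len to the build as -DSPEC_Q_SEQ_LEN.

def spec_dec_flags(q_seq_len: int, head_group_ratio: int) -> list:
    if q_seq_len > 1:
        if q_seq_len * head_group_ratio <= 32:
            # Bake the exact q_seq_len into the build when the check allows it.
            return ["-DSPEC_DEC=1", f"-DSPEC_Q_SEQ_LEN={q_seq_len}"]
        return ["-DSPEC_DEC=1"]
    return ["-DSPEC_DEC=0"]

print(spec_dec_flags(4, 8))  # ['-DSPEC_DEC=1', '-DSPEC_Q_SEQ_LEN=4']
print(spec_dec_flags(1, 8))  # ['-DSPEC_DEC=0']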
46 changes: 40 additions & 6 deletions flashinfer/xqa.py
@@ -39,6 +39,7 @@ def get_xqa_module(
head_group_ratio: int,
use_sliding_window: bool,
output_dtype: torch.dtype,
q_seq_len: int,
):
module = gen_xqa_module(
input_dtype,
@@ -48,10 +49,16 @@ def get_xqa_module(
head_group_ratio,
use_sliding_window,
output_dtype,
q_seq_len,
).build_and_load()

if q_seq_len > 1:
use_spec_dec = True
else:
use_spec_dec = False

@register_custom_op(
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}",
mutates_args=("output", "workspace_buffer"),
)
def xqa(
@@ -74,6 +81,8 @@ def xqa(
semaphores: torch.Tensor,
workspace_buffer: torch.Tensor,
enable_pdl: bool,
q_seq_len: int,
mask: Optional[torch.Tensor],
) -> None:
module.xqa_wrapper(
run_sm90_fp8_mha,
@@ -92,13 +101,15 @@ def xqa(
seq_lens,
batch_size,
kv_scale,
q_seq_len,
mask,
semaphores,
workspace_buffer,
enable_pdl,
)

@register_fake_op(
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}"
f"flashinfer::xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_output_{filename_safe_dtype_map[output_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_use_spec_dec_{use_spec_dec}"
)
def _fake_xqa(
run_sm90_fp8_mha: bool,
@@ -119,6 +130,9 @@ def _fake_xqa(
kv_scale: float,
semaphores: torch.Tensor,
workspace_buffer: torch.Tensor,
enable_pdl: bool,
q_seq_len: int,
mask: Optional[torch.Tensor],
) -> None:
pass

@@ -146,12 +160,15 @@ def xqa(
sm_count: Optional[int] = None,
enable_pdl: Optional[bool] = None,
rcp_out_scale: float = 1.0,
q_seq_len: int = 1,
mask: Optional[torch.Tensor] = None,
) -> None:
r"""Apply attention with paged KV cache using XQA kernel.
Parameters
----------
q : torch.Tensor
Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]`` if not using speculative decoding,
or ``[batch_size, beam_width, q_seq_len, num_q_heads, head_dim]`` if using speculative decoding. ``q_seq_len`` is the number of speculative decoding tokens.
Data type should be torch.float16 or torch.bfloat16.
Now only beam_width 1 is supported.
k_cache: torch.Tensor
@@ -172,7 +189,7 @@
Sequence lengths tensor with shape ``[batch_size, beam_width]``.
Data type should be torch.uint32.
output : torch.Tensor
Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
Output tensor with shape that matches the query tensor.
Data type should match query tensor or kv tensor. This tensor will be modified in-place.
workspace_buffer : torch.Tensor
Workspace buffer for temporary computations.
@@ -204,12 +221,19 @@
If None, will be set to True if hardware supports it.
rcp_out_scale : float, default=1.0
Reciprocal of output scale factor.
q_seq_len : int, default=1
Query sequence length. When > 1, enables speculative decoding mode.
mask : Optional[torch.Tensor], default=None
Causal attention mask for speculative decoding mode (when ``q_seq_len > 1``).
Shape: ``[batch_size, q_seq_len, mask_size_per_row]`` where
``mask_size_per_row = ((q_seq_len + 31) // 32) * 2``.
Data type should be torch.uint16 (bit-packed format, aligned to 32 bits).

Note
----
The function automatically infers several parameters from tensor shapes:
- batch_size from q.shape[0]
- num_q_heads from q.shape[2]
- num_q_heads from q.shape[-2]
- head_dim from q.shape[-1]
- input_dtype from q.dtype
- kv_cache_dtype from k.dtype
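A minimal sketch (not part of the PR) of building a bit-packed causal mask that matches the shape and dtype documented above for the mask parameter. The overall shape [batch_size, q_seq_len, ((q_seq_len + 31) // 32) * 2] and the torch.uint16 dtype come from the docstring; the bit order inside each 32-bit group (little-endian, two consecutive uint16 words, bit 0 = token 0) is an assumption and should be verified against the kernel.

import numpy as np
import torch

def build_packed_causal_mask(batch_size: int, q_seq_len: int) -> torch.Tensor:
    n_words32 = (q_seq_len + 31) // 32
    packed = np.zeros((batch_size, q_seq_len, n_words32), dtype=np.uint32)
    for i in range(q_seq_len):       # draft/query token i
        for j in range(i + 1):       # causal: token i may attend to tokens 0..i
            packed[:, i, j // 32] |= np.uint32(1 << (j % 32))
    # Reinterpret each uint32 group as two consecutive uint16 words,
    # giving shape [batch_size, q_seq_len, 2 * n_words32] as documented.
    mask_u16 = packed.view(np.uint16)
    return torch.from_numpy(mask_u16.copy())  # torch.uint16 requires a recent PyTorch

A tensor built this way would be passed together with q_seq_len > 1 as the mask argument of xqa (or the mask keyword of xqa_batch_decode_with_kv_cache).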
@@ -224,7 +248,7 @@

# Infer parameters from tensors
batch_size = q.shape[0]
num_q_heads = q.shape[2]
num_q_heads = q.shape[-2]
head_dim = q.shape[-1]

# Calculate head_group_ratio
@@ -271,7 +295,15 @@
head_group_ratio,
use_sliding_window,
output.dtype,
q_seq_len,
)

if q_seq_len > 1:
assert mask is not None, "Mask is required for speculative decoding"
run_sm90_fp8_mha = (
False # TODO: mha_sm90.cu has precision issue with speculative decoding
)

xqa_module.xqa(
run_sm90_fp8_mha,
sm_count,
@@ -292,6 +324,8 @@
semaphores,
workspace_buffer,
enable_pdl,
q_seq_len,
mask,
)


Expand Down
Loading