25 changes: 8 additions & 17 deletions csrc/flashinfer_xqa_binding.cu
@@ -18,14 +18,10 @@

#if MLA_WRAPPER
void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
#if PAGED_KV_CACHE_LAYOUT == 1
TensorView kCacheVLLM, TensorView vCacheVLLM,
#else
TensorView pool,
#endif
TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
TensorView scratch);
TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList,
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize,
TensorView kvCacheScale, TensorView semaphores, TensorView scratch,
bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);

@@ -36,18 +32,13 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
#if LOW_PREC_OUTPUT
TensorView rcpOutScale,
#endif
TensorView q, tvm::ffi::Optional<TensorView> attentionSinks,
#if PAGED_KV_CACHE_LAYOUT == 1
TensorView kCacheVLLM, TensorView vCacheVLLM,
#else
TensorView pool,
#endif
TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
int64_t batchSize, TensorView kvCacheScale,
TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
TensorView seqLen, int64_t batchSize, TensorView kvCacheScale,
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch);
TensorView semaphores, TensorView scratch, bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);

11 changes: 7 additions & 4 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -228,15 +228,17 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) || head_dim_v == head_dim_o)
<< "head_dim_v and head_dim_o must be the same for non-MLA attention, got "
<< std::to_string(head_dim_v) << " and " << std::to_string(head_dim_o);
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);
int max_num_blocks_per_seq = block_tables.size(-1);
bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr();
int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2;

// Assume NHD layout: [..., H, N, D]
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);
int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch

int kv_stride_batch = key_cache.stride(0); // batch

const auto stream = get_stream(query.device());
void* output_sf_ptr =
@@ -291,9 +293,10 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
int max_num_blocks_per_seq = block_tables.size(-1);
bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr();
int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2;

// Assume NHD layout: [..., H, N, D]
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);

int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch
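Both launchers derive the cache geometry from the key_cache tensor the same way. Below is a minimal standalone sketch of that arithmetic, assuming a contiguous cache shaped [num_pages, num_kv_heads, page_size, head_dim]; the concrete sizes and variable names are illustrative only and do not come from the FlashInfer sources.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed shape: [num_pages, H, N, D] with row-major (innermost-last) layout.
  std::array<int64_t, 4> shape = {4096, 8, 32, 128};
  std::array<int64_t, 4> stride;
  stride[3] = 1;
  for (int i = 2; i >= 0; --i) stride[i] = stride[i + 1] * shape[i + 1];

  // Mirrors what the launcher reads: size(-2) is the page size, size(-3) the
  // number of KV heads; stride(-2) steps between tokens within a page,
  // stride(-3) between heads, stride(0) between pages in the pool.
  int64_t page_size = shape[2];               // key_cache.size(-2)
  int64_t num_kv_heads = shape[1];            // key_cache.size(-3)
  int64_t kv_stride_keys_values = stride[2];  // key_cache.stride(-2)
  int64_t kv_stride_heads = stride[1];        // key_cache.stride(-3)
  int64_t kv_stride_batch = stride[0];        // key_cache.stride(0)

  // When K and V live in separate pools (different data pointers), the launcher
  // doubles key_cache.size(0) to obtain num_pages_in_mem_pool.
  std::printf("page_size=%lld heads=%lld token_stride=%lld head_stride=%lld page_stride=%lld\n",
              (long long)page_size, (long long)num_kv_heads,
              (long long)kv_stride_keys_values, (long long)kv_stride_heads,
              (long long)kv_stride_batch);
  return 0;
}
```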
27 changes: 10 additions & 17 deletions csrc/xqa/defines.h
@@ -92,21 +92,6 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define TOKENS_PER_PAGE 32
#endif

// don't modify
#ifndef USE_PAGED_KV_CACHE
#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0)
#endif

// Paged KV Cache Format
// 0 - XQA Original
// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for
// VLLM/SGLang
#ifdef USE_PAGED_KV_CACHE
#ifndef PAGED_KV_CACHE_LAYOUT
#define PAGED_KV_CACHE_LAYOUT 0
#endif
#endif

// don't modify
#define USE_BEAM_SEARCH (BEAM_WIDTH > 1)

@@ -129,7 +114,16 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
#ifndef ENABLE_PDL
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#if __CUDA_ARCH__ == 900
#define ENABLE_PDL 2
#else
#define ENABLE_PDL 1
#endif
#else
/* default for host or older architectures */
#define ENABLE_PDL 0
#endif
#endif
Comment on lines 116 to 127

⚠️ Potential issue | 🟠 Major

Runtime-vs-compile-time PDL mismatch

ENABLE_PDL defaults to 1/2 for SM ≥ 90 at compile time, but the kernels still key off this macro while the host now passes a runtime enable_pdl. Host compilation sees ENABLE_PDL==0 (no __CUDA_ARCH__), so kernels may execute preExit/acqBulk even when enable_pdl=false at launch. This is inconsistent and can lead to invalid use of programmatic stream serialization.

Thread enable_pdl into the kernels and guard PDL intrinsics with the runtime flag (while keeping the arch guards). See follow-up diffs in kernel files below.

🤖 Prompt for AI Agents
In csrc/xqa/defines.h around lines 131-142, the ENABLE_PDL selection based solely on __CUDA_ARCH__ makes kernels compile PDL intrinsics unconditionally, even though the host now uses a runtime enable_pdl flag. Update the headers and usages so that kernels still respect the arch guards but also check a passed-in runtime boolean (e.g., enable_pdl) before invoking PDL intrinsics: keep the existing __CUDA_ARCH__ checks to determine PDL availability at compile time, thread a runtime enable_pdl parameter into the kernels (through kernel arguments or captured in device lambdas), and wrap all calls to preExit/acqBulk (and other PDL intrinsics) in a combined condition that requires both compile-time availability and the runtime flag, so that no PDL intrinsics execute on device when enable_pdl is false at launch.
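A minimal sketch of the pattern this comment asks for, assuming the runtime flag is threaded in as a kernel argument; the kernel name and body are illustrative stand-ins for the real XQA kernels and their preExit/acqBulk wrappers. The compile-time arch guard stays, and the intrinsics are additionally gated on the runtime flag:

```cuda
#include <cuda_runtime.h>

__global__ void attentionKernelSketch(float* out, const float* in, int n,
                                      bool enable_pdl) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (enable_pdl) {
    // Wait for the upstream grid in the stream to finish producing our input.
    cudaGridDependencySynchronize();
  }
#endif
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.f * in[i];
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  if (enable_pdl) {
    // Signal that the dependent grid may begin launching early.
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif
}
```

On the host side, PDL also requires launching the dependent kernel with the programmatic stream serialization attribute (for example via cudaLaunchKernelEx with cudaLaunchAttributeProgrammaticStreamSerialization), which would be gated on the same enable_pdl flag.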


#ifndef USE_INPUT_KV
@@ -161,8 +155,7 @@ static_assert(CACHE_ELEM_ENUM != 0);
#endif

// true should be better if warpTile.x * cacheElemSize < 128. otherwise use false.
#define GRP_LOAD_V \
(CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && USE_PAGED_KV_CACHE && BEAM_WIDTH > 1)
#define GRP_LOAD_V (CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && BEAM_WIDTH > 1)

// use custom barrier for NVRTC to avoid pulling in many headers
#ifndef USE_CUSTOM_BARRIER