4 changes: 2 additions & 2 deletions csrc/flashinfer_xqa_binding.cu
@@ -25,7 +25,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
#endif
TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
TensorView scratch);
TensorView scratch, bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);

@@ -47,7 +47,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch);
TensorView semaphores, TensorView scratch, bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);

31 changes: 22 additions & 9 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -200,7 +200,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
TensorView seq_lens, int64_t max_kv_len, double bmm1_scale,
double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
bool enable_pdl, int64_t workspace_size,
int64_t kv_layout, bool enable_pdl, int64_t workspace_size,
Optional<TensorView> attention_sinks) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
@@ -234,9 +234,16 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr();
int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2;

int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch
int kv_stride_keys_values, kv_stride_heads;
if (kv_layout == 0) { // nhd
kv_stride_keys_values = key_cache.stride(-3); // key/values
kv_stride_heads = key_cache.stride(-2); // head
} else { // kv_layout == 1, hnd
kv_stride_keys_values = key_cache.stride(-2); // key/values
kv_stride_heads = key_cache.stride(-3); // head
}

int kv_stride_batch = key_cache.stride(0); // batch
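For orientation, a small host-side sketch of what the branch above selects (sizes and names are illustrative, not taken from this PR): in the NHD layout a page is stored as [page_size, num_kv_heads, head_dim] and in HND as [num_kv_heads, page_size, head_dim], so the token stride and the head stride swap between stride(-3) and stride(-2).

// Sketch only: contiguous strides for the two paged KV-cache layouts.
// The real cache is a higher-rank paged tensor; the launcher only consumes
// key_cache.stride(-2) / key_cache.stride(-3), which these two cases model.
#include <cstdio>

int main() {
  const int page_size = 16, num_kv_heads = 8, head_dim = 128;  // assumed sizes

  // kv_layout == 0 (NHD): [..., page_size, num_kv_heads, head_dim]
  int nhd_token_stride = num_kv_heads * head_dim;  // stride(-3) -> kv_stride_keys_values
  int nhd_head_stride = head_dim;                  // stride(-2) -> kv_stride_heads

  // kv_layout == 1 (HND): [..., num_kv_heads, page_size, head_dim]
  int hnd_token_stride = head_dim;                 // stride(-2) -> kv_stride_keys_values
  int hnd_head_stride = page_size * head_dim;      // stride(-3) -> kv_stride_heads

  printf("NHD: token stride %d, head stride %d\n", nhd_token_stride, nhd_head_stride);
  printf("HND: token stride %d, head stride %d\n", hnd_token_stride, hnd_head_stride);
  return 0;
}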

const auto stream = get_stream(query.device());
void* output_sf_ptr =
@@ -270,8 +277,8 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, int64_t workspace_size,
Optional<TensorView> attention_sinks) {
int64_t sm_count, int64_t kv_layout, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
@@ -294,9 +301,15 @@
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);

int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch
int kv_stride_keys_values, kv_stride_heads;
if (kv_layout == 0) { // nhd
kv_stride_keys_values = key_cache.stride(-3); // key/values
kv_stride_heads = key_cache.stride(-2); // head
} else { // kv_layout == 1, hnd
kv_stride_keys_values = key_cache.stride(-2); // key/values
kv_stride_heads = key_cache.stride(-3); // head
}
int kv_stride_batch = key_cache.stride(0); // batch

const auto stream = get_stream(query.device());
void* output_sf_ptr =
9 changes: 9 additions & 0 deletions csrc/xqa/defines.h
@@ -129,7 +129,16 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
#ifndef ENABLE_PDL
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#if __CUDA_ARCH__ == 900
#define ENABLE_PDL 2
#else
#define ENABLE_PDL 1
#endif
#else
/* default for host or older architectures */
#define ENABLE_PDL 0
#endif
#endif
Comment on lines 116 to 127

⚠️ Potential issue | 🟠 Major

Runtime-vs-compile-time PDL mismatch

ENABLE_PDL defaults to 1/2 for SM≥90 at compile time, but kernels still use this macro while the host now passes a runtime enable_pdl. Host compilation sees ENABLE_PDL==0 (no __CUDA_ARCH__), while device compilation keeps it at 1/2, so kernels may execute preExit/acqBulk even when enable_pdl=false at launch. This is inconsistent and can lead to invalid usage of programmatic stream serialization.

Thread enable_pdl into the kernels and guard PDL intrinsics with the runtime flag (while keeping the arch guards). See follow-up diffs in kernel files below.

πŸ€– Prompt for AI Agents
In csrc/xqa/defines.h around lines 131-142, the current ENABLE_PDL macro
selection based solely on __CUDA_ARCH__ causes kernels to unconditionally
compile PDL intrinsics even though the host uses a runtime enable_pdl flag;
update the headers and usages so that kernels still respect the arch guards but
also check a passed-in runtime boolean (e.g., enable_pdl) before invoking PDL
intrinsics: keep the existing __CUDA_ARCH__ checks to determine PDL availability
at compile time, expose a runtime enable_pdl parameter into kernels (thread it
through kernel arguments or capture in device lambdas), and wrap all calls to
preExit/acqBulk (and other PDL intrinsics) with a combined condition that
requires both compile-time availability and the runtime flag so that when
enable_pdl is false at launch no PDL intrinsics execute on device.
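One possible shape for that guard, sketched under assumptions: the helper names below are invented for illustration, and the raw CUDA PDL intrinsics are called directly, whereas this codebase wraps them as preExit/acqBulk.

// Sketch: PDL calls gated by both compile-time availability and a runtime flag.
// cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion()
// are the device-side PDL entry points (compute capability 9.0+, nvcc build).
__device__ __forceinline__ void gridDependencyWaitIf(bool enable_pdl) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && ENABLE_PDL != 0
  if (enable_pdl) {
    cudaGridDependencySynchronize();  // wait until the prior grid's writes are visible
  }
#else
  (void)enable_pdl;  // no PDL on this architecture / configuration
#endif
}

__device__ __forceinline__ void gridLaunchCompleteIf(bool enable_pdl) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && ENABLE_PDL != 0
  if (enable_pdl) {
    cudaTriggerProgrammaticLaunchCompletion();  // let the dependent grid start early
  }
#else
  (void)enable_pdl;
#endif
}

The host side of this PR already threads enable_pdl into makeLaunchConfig, so only the device-side call sites would need the extra argument.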


#ifndef USE_INPUT_KV
11 changes: 6 additions & 5 deletions csrc/xqa/mha.cu
@@ -89,7 +89,7 @@ constexpr uint32_t cvtExpansion = exactDiv(inputElemSize, cacheElemSize);
constexpr uint32_t preferedKHeadPartBytes = 64;
__constant__ constexpr uint32_t cacheVTileSeqLen = 32;
#else
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210
constexpr uint32_t preferedKHeadPartBytes = 64;
__constant__ constexpr uint32_t cacheVTileSeqLen = 32;
#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || \
@@ -2548,7 +2548,7 @@ void launchMHA(
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream) {
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
#if SPEC_DEC
auto const qSeqLen = specDecParams.qSeqLen;
auto const qCuSeqLens = specDecParams.qCuSeqLens;
@@ -2590,7 +2590,7 @@
dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize};
#endif
dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
#if PAGED_KV_CACHE_LAYOUT == 1
@@ -2681,7 +2681,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
#if SPEC_DEC
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream) {
uint32_t* semaphores, void* scratch, bool enable_pdl,
cudaStream_t stream) {
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
if (!allowMultiBlockMode) {
return 1;
@@ -2696,7 +2697,7 @@
dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize};
#endif
dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
#if PAGED_KV_CACHE_LAYOUT == 1
13 changes: 7 additions & 6 deletions csrc/xqa/mha.h
@@ -128,7 +128,7 @@ void launchMHA(
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
float qScale, OutputHead* output,
@@ -147,7 +147,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
#if SPEC_DEC
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

void launchHopperF8MHA(
cudaDeviceProp const& prop, uint32_t nbKHeads,
@@ -189,7 +189,7 @@ void launchHopperF8MHA(
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads,
uint32_t slidingWinSize, float qScale, OutputHead* output,
@@ -208,7 +208,8 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
#if SPEC_DEC
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t* semaphores, void* scratch, bool enable_pdl,
cudaStream_t stream);

void launchMLA(
cudaDeviceProp const& prop,
@@ -230,7 +231,7 @@ void launchMLA(
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
// Used only for int8/fp8 KV cache.
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

void launchMLAFlashInfer(
uint32_t multiProcessorCount,
@@ -248,7 +249,7 @@ void launchMLAFlashInfer(
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
// Used only for int8/fp8 KV cache.
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

#if STATIC_NB_K_HEADS
constexpr uint32_t nbKHeads = NB_K_HEADS;
18 changes: 11 additions & 7 deletions csrc/xqa/mha_sm90.cu
@@ -1966,9 +1966,12 @@ __device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gme
for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) {
static_assert(nbThrdsPerInstNBase * RegColWiseVec::size ==
exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols));
ret[i] = reinterpret_cast<Vec<Vec<float, GmmaAccCoreMat::cols>,
exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>(
gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)];
uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
ret[i][j] = gmemVec[baseOffset + j];
}
}
Comment on lines +1877 to 1883

⚠️ Potential issue | πŸ”΄ Critical

Out-of-bounds read in loadGmemColWiseVecWithDup for attention sinks

gmemVec points to a buffer of size headGrpSize (see finalizeAndWriteOut_sync passing attentionSinksVec[0]), but this code multiplies the index by GmmaAccCoreMat::cols and reads baseOffset+j, which can exceed headGrpSize. We should load a single sink value per head and duplicate it across columns, without advancing memory by cols.

Apply this fix:

-    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
-    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
-#pragma unroll
-    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
-      ret[i][j] = gmemVec[baseOffset + j];
-    }
+    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+#pragma unroll
+    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+      // Duplicate the same head sink across the 2 columns
+      ret[i][j] = gmemVec[clampedIdx];
+    }

return ret;
}
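A tiny standalone check of the index math behind the fix suggested above (sizes are made up; the real values come from headGrpSize and GmmaAccCoreMat in mha_sm90.cu):

// Sketch: with one sink value per head and GmmaAccCoreMat::cols == 2, striding
// the clamped index by cols can run past the per-head buffer, while reading
// gmemVec[clampedIdx] and duplicating it across columns stays in bounds.
#include <cassert>
#include <cstdio>

int main() {
  const unsigned headGrpSize = 8;          // assumed: sink buffer holds one value per head
  const unsigned cols = 2;                 // GmmaAccCoreMat::cols
  const unsigned bound = headGrpSize - 1;  // assumed: index is clamped to the last head

  const unsigned worstOld = bound * cols + (cols - 1);  // 15 -> past an 8-entry buffer
  const unsigned worstNew = bound;                      // 7  -> always < headGrpSize

  printf("old max index %u, fixed max index %u, buffer size %u\n", worstOld, worstNew,
         headGrpSize);
  assert(worstNew < headGrpSize);
  return 0;
}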
@@ -3033,7 +3036,7 @@ void launchHopperF8MHA(
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream) {
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
if (beamWidth != 1) {
throw std::runtime_error("not implemented");
}
Expand Down Expand Up @@ -3070,7 +3073,7 @@ void launchHopperF8MHA(
// nbInputSeqSplit
dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
auto const dtype = [] {
@@ -3191,7 +3194,8 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
#if SPEC_DEC
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream) {
uint32_t* semaphores, void* scratch, bool enable_pdl,
cudaStream_t stream) {
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
float const factor = 0.25f;
return mha::min<uint32_t>(
@@ -3207,7 +3211,7 @@
#endif
dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
auto const dtype = [] {
8 changes: 4 additions & 4 deletions csrc/xqa/mla_sm120.cu
@@ -1724,7 +1724,7 @@ void launchMLA(
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
// Used only for int8/fp8 KV cache.
uint32_t* semaphores, void* scratch, cudaStream_t stream) {
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
#if IS_MLA
static_assert(
SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,
Expand Down Expand Up @@ -1762,7 +1762,7 @@ void launchMLA(
// nbInputSeqSplit
dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
dim3 const dimCta{warp_size * 4 * 3, 1, 1};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
#if PAGED_KV_CACHE_LAYOUT == 1
@@ -1861,7 +1861,7 @@ void launchMLAFlashInfer(
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
// Used only for int8/fp8 KV cache.
uint32_t* semaphores, void* scratch, cudaStream_t stream) {
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
#if IS_MLA
static_assert(
SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,
@@ -1885,7 +1885,7 @@
// nbInputSeqSplit
dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
dim3 const dimCta{warp_size * 4 * 3, 1, 1};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
#if PAGED_KV_CACHE_LAYOUT == 1
2 changes: 1 addition & 1 deletion csrc/xqa/utils.cuh
@@ -42,7 +42,7 @@ inline constexpr int32_t kBAD_PAGE_INDEX = -1;
__constant__ constexpr float kE4M3_MAX = 448.F;

#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210
constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10);
#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870
constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10);
8 changes: 4 additions & 4 deletions csrc/xqa/xqa_wrapper.cu
@@ -28,7 +28,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
#endif
TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
TensorView scratch) {
TensorView scratch, bool enable_pdl) {
auto stream = get_stream(output.device());

launchMLAFlashInfer(multiProcessorCount, 1, qScale,
@@ -44,7 +44,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, stream);
}
#else

@@ -64,7 +64,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
#if SPEC_DEC
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
#endif
TensorView semaphores, TensorView scratch) {
TensorView semaphores, TensorView scratch, bool enable_pdl) {
auto stream = get_stream(output.device());
float const* attentionSinksPtr =
attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks.value().data_ptr())
@@ -91,6 +91,6 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, stream);
}
#endif