Commit 9c08d33

move enable_pdl to runtime
1 parent 595ee1b commit 9c08d33

File tree

10 files changed: +65 -55 lines

csrc/flashinfer_xqa_binding.cu

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
 #endif
     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
     int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
-    TensorView scratch);
+    TensorView scratch, bool enable_pdl);

 TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);

@@ -47,7 +47,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
 #if SPEC_DEC
     int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
 #endif
-    TensorView semaphores, TensorView scratch);
+    TensorView semaphores, TensorView scratch, bool enable_pdl);

 TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper);

csrc/xqa/defines.h

Lines changed: 9 additions & 0 deletions

@@ -129,7 +129,16 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
 // 1 - naive PDL
 // 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
 #ifndef ENABLE_PDL
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+#if __CUDA_ARCH__ == 900
 #define ENABLE_PDL 2
+#else
+#define ENABLE_PDL 1
+#endif
+#else
+/* default for host or older architectures */
+#define ENABLE_PDL 0
+#endif
 #endif

 #ifndef USE_INPUT_KV
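
Note (illustrative, not part of this commit): ENABLE_PDL still gates the kernels' device-side PDL code; the new defaults above only choose a level per architecture when no -DENABLE_PDL is supplied. A minimal sketch of the pattern such a macro typically guards on sm_90 and newer follows; the kernel name is a placeholder, not code from this repository.

// Standalone sketch only -- consumerKernelSketch is a placeholder kernel.
#include <cuda_runtime.h>

#ifndef ENABLE_PDL  // placeholder default mirroring csrc/xqa/defines.h
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#define ENABLE_PDL 1
#else
#define ENABLE_PDL 0
#endif
#endif

__global__ void consumerKernelSketch(float const* producedByPrevKernel, float* out, int n) {
#if ENABLE_PDL
  // Prologue work that does not read the previous kernel's output may run before
  // this point; synchronize before consuming its results.
  cudaGridDependencySynchronize();
#endif
  int const i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 2.f * producedByPrevKernel[i];
  }
#if ENABLE_PDL
  // Allow a dependent kernel queued behind this one to start launching early.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}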

csrc/xqa/mha.cu

Lines changed: 5 additions & 4 deletions

@@ -2548,7 +2548,7 @@ void launchMHA(
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
 #if SPEC_DEC
   auto const qSeqLen = specDecParams.qSeqLen;
   auto const qCuSeqLens = specDecParams.qCuSeqLens;
@@ -2590,7 +2590,7 @@ void launchMHA(
   dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize};
 #endif
   dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
 #if PAGED_KV_CACHE_LAYOUT == 1
@@ -2681,7 +2681,8 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl,
+    cudaStream_t stream) {
   uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
     if (!allowMultiBlockMode) {
       return 1;
@@ -2696,7 +2697,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
   dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize};
 #endif
   dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
 #if PAGED_KV_CACHE_LAYOUT == 1
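
Note (illustrative, not part of this commit): the launchers now forward the runtime enable_pdl flag to makeLaunchConfig instead of baking in ENABLE_PDL != 0. makeLaunchConfig itself is not touched by this diff; below is a minimal sketch, assuming it sets CUDA's programmatic stream serialization launch attribute via the extensible launch API (CUDA 11.8+). The kernel and helper names are placeholders.

// Standalone sketch only -- dummyKernel and launchWithOptionalPdl are placeholders,
// not code from this repository.
#include <cuda_runtime.h>

__global__ void dummyKernel(float* out) { out[blockIdx.x * blockDim.x + threadIdx.x] = 1.f; }

cudaError_t launchWithOptionalPdl(float* out, dim3 dimGrid, dim3 dimCta, size_t smemBytes,
                                  cudaStream_t stream, bool enable_pdl) {
  // Decide per launch whether programmatic dependent launch (PDL) is allowed.
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;

  cudaLaunchConfig_t cfg{};
  cfg.gridDim = dimGrid;
  cfg.blockDim = dimCta;
  cfg.dynamicSmemBytes = smemBytes;
  cfg.stream = stream;
  cfg.attrs = &attr;  // attribute applies to this launch only
  cfg.numAttrs = 1;

  return cudaLaunchKernelEx(&cfg, dummyKernel, out);
}

Because the attribute is set per launch, the same compiled kernel can be launched with or without PDL, which is consistent with the JIT changes further down that drop the -DENABLE_PDL build flag and the enable_pdl suffix from the module name.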

csrc/xqa/mha.h

Lines changed: 7 additions & 6 deletions

@@ -128,7 +128,7 @@ void launchMHA(
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

 void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
                          float qScale, OutputHead* output,
@@ -147,7 +147,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

 void launchHopperF8MHA(
     cudaDeviceProp const& prop, uint32_t nbKHeads,
@@ -189,7 +189,7 @@ void launchHopperF8MHA(
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

 void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads,
                                  uint32_t slidingWinSize, float qScale, OutputHead* output,
@@ -208,7 +208,8 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl,
+    cudaStream_t stream);

 void launchMLA(
     cudaDeviceProp const& prop,
@@ -230,7 +231,7 @@ void launchMLA(
     uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,  // Device memory scalar. Same scale for K and V cache.
                                              // Used only for int8/fp8 KV cache.
-    uint32_t* semaphores, void* scratch, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

 void launchMLAFlashInfer(
     uint32_t multiProcessorCount,
@@ -248,7 +249,7 @@ void launchMLAFlashInfer(
     uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,  // Device memory scalar. Same scale for K and V cache.
                                              // Used only for int8/fp8 KV cache.
-    uint32_t* semaphores, void* scratch, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

 #if STATIC_NB_K_HEADS
 constexpr uint32_t nbKHeads = NB_K_HEADS;

csrc/xqa/mha_sm90.cu

Lines changed: 5 additions & 4 deletions

@@ -3036,7 +3036,7 @@ void launchHopperF8MHA(
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
   if (beamWidth != 1) {
     throw std::runtime_error("not implemented");
   }
@@ -3073,7 +3073,7 @@ void launchHopperF8MHA(
   // nbInputSeqSplit
   dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
   dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
   auto const dtype = [] {
@@ -3194,7 +3194,8 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl,
+    cudaStream_t stream) {
   uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
     float const factor = 0.25f;
     return mha::min<uint32_t>(
@@ -3210,7 +3211,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
 #endif
   dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
   dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
   auto const dtype = [] {

csrc/xqa/mla_sm120.cu

Lines changed: 4 additions & 4 deletions

@@ -1724,7 +1724,7 @@ void launchMLA(
     uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,  // Device memory scalar. Same scale for K and V cache.
                                              // Used only for int8/fp8 KV cache.
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
 #if IS_MLA
   static_assert(
       SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,
@@ -1762,7 +1762,7 @@ void launchMLA(
   // nbInputSeqSplit
   dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
   dim3 const dimCta{warp_size * 4 * 3, 1, 1};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
 #if PAGED_KV_CACHE_LAYOUT == 1
@@ -1861,7 +1861,7 @@ void launchMLAFlashInfer(
     uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,  // Device memory scalar. Same scale for K and V cache.
                                              // Used only for int8/fp8 KV cache.
-    uint32_t* semaphores, void* scratch, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream) {
 #if IS_MLA
   static_assert(
       SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,
@@ -1885,7 +1885,7 @@ void launchMLAFlashInfer(
   // nbInputSeqSplit
   dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
   dim3 const dimCta{warp_size * 4 * 3, 1, 1};
-  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
+  auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
 #if PAGED_KV_CACHE_LAYOUT == 1

csrc/xqa/xqa_wrapper.cu

Lines changed: 4 additions & 4 deletions

@@ -28,7 +28,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
 #endif
     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
     int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
-    TensorView scratch) {
+    TensorView scratch, bool enable_pdl) {
   auto stream = get_stream(output.device());

   launchMLAFlashInfer(multiProcessorCount, 1, qScale,
@@ -44,7 +44,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
       maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
       reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
       reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
-      reinterpret_cast<void*>(scratch.data_ptr()), stream);
+      reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, stream);
 }
 #else

@@ -64,7 +64,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
 #if SPEC_DEC
     int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
 #endif
-    TensorView semaphores, TensorView scratch) {
+    TensorView semaphores, TensorView scratch, bool enable_pdl) {
   auto stream = get_stream(output.device());
   float const* attentionSinksPtr =
       attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks.value().data_ptr())
@@ -91,6 +91,6 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
       reinterpret_cast<MaskType const*>(mask.data_ptr()),
 #endif
       reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
-      reinterpret_cast<void*>(scratch.data_ptr()), stream);
+      reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, stream);
 }
 #endif

flashinfer/aot.py

Lines changed: 0 additions & 2 deletions

@@ -404,7 +404,6 @@ def gen_xqa(
             head_dim=head_size,
             head_group_ratio=head_grp_size,
             use_sliding_window=use_sliding_window,
-            enable_pdl=True,
         )

     if has_sm120 or has_sm121:
@@ -416,7 +415,6 @@ def gen_xqa(
             head_dim=576,
             head_group_ratio=128,
             use_sliding_window=False,
-            enable_pdl=True,
         )

flashinfer/jit/xqa.py

Lines changed: 3 additions & 17 deletions

@@ -42,7 +42,6 @@ def gen_xqa_module(
     head_dim: int,
     head_group_ratio: int,
     use_sliding_window: bool,
-    enable_pdl: bool,
 ) -> JitSpec:
     if input_dtype == torch.float16:
         flag_input_dtype = ["-DINPUT_FP16=1", "-DDTYPE=__half"]
@@ -85,15 +84,10 @@ def gen_xqa_module(
     )
     sm_nvcc_flags = nvcc_flags

-    if enable_pdl:
-        flag_enable_pdl = ["-DENABLE_PDL=2"]
-    else:
-        flag_enable_pdl = ["-DENABLE_PDL=0"]
-
     flag_mla_wrapper = ["-DMLA_WRAPPER=0"]

     return gen_jit_spec(
-        f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}",
+        f"xqa_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
         [
            jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu",
            jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu",
@@ -109,7 +103,6 @@ def gen_xqa_module(
        + flag_kv_cache_dtype
        + flag_head_group_ratio
        + flag_sliding_window
-       + flag_enable_pdl
        + flag_mla_wrapper,
        extra_ldflags=["-lcuda"],  # Add CUDA Driver API library
        extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"],
@@ -123,7 +116,6 @@ def gen_xqa_module_mla(
     head_dim: int,
     head_group_ratio: int,
     use_sliding_window: bool = False,
-    enable_pdl: bool = True,
 ) -> JitSpec:
     assert head_group_ratio == 128, "Only head group ratio 128 is supported for xqa MLA"
     assert head_dim == 576, "Only head dim 576 is supported for xqa_module_mla"
@@ -153,15 +145,10 @@ def gen_xqa_module_mla(
     nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[12])
     sm_nvcc_flags = nvcc_flags

-    if enable_pdl:
-        flag_enable_pdl = ["-DENABLE_PDL=2"]
-    else:
-        flag_enable_pdl = ["-DENABLE_PDL=0"]
-
     flag_mla_wrapper = ["-DMLA_WRAPPER=1"]

     return gen_jit_spec(
-        f"xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}_enable_pdl_{enable_pdl}",
+        f"xqa_mla_input_{filename_safe_dtype_map[input_dtype]}_kv_cache_{filename_safe_dtype_map[kv_cache_dtype]}_page_size_{page_size}_head_dim_{head_dim}_head_group_ratio_{head_group_ratio}_use_sliding_window_{use_sliding_window}",
         [
            jit_env.FLASHINFER_CSRC_DIR / "xqa/mla_sm120.cu",
            jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp",
@@ -175,8 +162,7 @@ def gen_xqa_module_mla(
        + flag_kv_cache_dtype
        + flag_head_group_ratio
        + flag_sliding_window
-       + flag_mla_wrapper
-       + flag_enable_pdl,
+       + flag_mla_wrapper,
        extra_ldflags=["-lcuda"],  # Add CUDA Driver API library
        extra_cflags=["-DPAGED_KV_CACHE_LAYOUT=1"],
     )
