@@ -128,7 +128,7 @@ void launchMHA(
128128#if SPEC_DEC
129129 SpecDecParams const & specDecParams,
130130#endif
131- uint32_t * semaphores, void * scratch, cudaStream_t stream);
131+ uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
132132
133133void launchMHAFlashInfer (uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
134134 float qScale, OutputHead* output,
@@ -147,7 +147,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
147147#if SPEC_DEC
148148 uint32_t qSeqLen, uint32_t const * qCuSeqLens, MaskType const * mask,
149149#endif
150- uint32_t * semaphores, void * scratch, cudaStream_t stream);
150+ uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
151151
152152void launchHopperF8MHA (
153153 cudaDeviceProp const & prop, uint32_t nbKHeads,
@@ -189,7 +189,7 @@ void launchHopperF8MHA(
189189#if SPEC_DEC
190190 SpecDecParams const & specDecParams,
191191#endif
192- uint32_t * semaphores, void * scratch, cudaStream_t stream);
192+ uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
193193
194194void launchHopperF8MHAFlashInfer (uint32_t multiProcessorCount, uint32_t nbKHeads,
195195 uint32_t slidingWinSize, float qScale, OutputHead* output,
@@ -208,7 +208,8 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
208208#if SPEC_DEC
209209 uint32_t qSeqLen, uint32_t const * qCuSeqLens, MaskType const * mask,
210210#endif
211- uint32_t * semaphores, void * scratch, cudaStream_t stream);
211+ uint32_t * semaphores, void * scratch, bool enable_pdl,
212+ cudaStream_t stream);
212213
213214void launchMLA (
214215 cudaDeviceProp const & prop,
@@ -230,7 +231,7 @@ void launchMLA(
230231 uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
231232 float const * __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
232233 // Used only for int8/fp8 KV cache.
233- uint32_t * semaphores, void * scratch, cudaStream_t stream);
234+ uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
234235
235236void launchMLAFlashInfer (
236237 uint32_t multiProcessorCount,
@@ -248,7 +249,7 @@ void launchMLAFlashInfer(
248249 uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
249250 float const * __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
250251 // Used only for int8/fp8 KV cache.
251- uint32_t * semaphores, void * scratch, cudaStream_t stream);
252+ uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
252253
253254#if STATIC_NB_K_HEADS
254255constexpr uint32_t nbKHeads = NB_K_HEADS;
0 commit comments