@@ -135,7 +135,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
135135#if LOW_PREC_OUTPUT
136136 float const * rcpOutScale,
137137#endif
138- InputHead const * q, float const * attentionSinks, GMemCacheHead* pool,
138+ InputHead const * q, float const * attentionSinks,
139+ #if PAGED_KV_CACHE_LAYOUT == 1
140+ GMemCacheHead* kCacheVLLM , GMemCacheHead* vCacheVLLM,
141+ #else
142+ GMemCacheHead* pool,
143+ #endif
139144 KVCachePageIndex const * kvCachePageList, uint32_t maxSeqLen,
140145 uint32_t const * seqLen, uint32_t batchSize,
141146 float const * __restrict__ kvCacheScale,
@@ -192,8 +197,13 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
192197 float const * rcpOutScale,
193198#endif
194199 InputHead const * q, float const * attentionSinks,
195- GMemCacheHead* pool, KVCachePageIndex const * kvCachePageList,
196- uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
200+ #if PAGED_KV_CACHE_LAYOUT == 1
201+ GMemCacheHead* kCacheVLLM , GMemCacheHead* vCacheVLLM,
202+ #else
203+ GMemCacheHead* pool,
204+ #endif
205+ KVCachePageIndex const * kvCachePageList, uint32_t maxSeqLen,
206+ uint32_t const * seqLen, uint32_t batchSize,
197207 float const * __restrict__ kvCacheScale,
198208#if SPEC_DEC
199209 uint32_t qSeqLen, uint32_t const * qCuSeqLens, MaskType const * mask,
0 commit comments