1818
1919#if MLA_WRAPPER
2020void xqa_wrapper_mla (int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
21- #if PAGED_KV_CACHE_LAYOUT == 1
22- TensorView kCacheVLLM , TensorView vCacheVLLM,
23- #else
24- TensorView pool,
25- #endif
26- TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
27- int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
28- TensorView scratch, bool enable_pdl);
21+ TensorView kCacheVLLM , TensorView vCacheVLLM, TensorView kvCachePageList,
22+ int64_t maxSeqLen, TensorView seqLen, int64_t batchSize,
23+ TensorView kvCacheScale, TensorView semaphores, TensorView scratch,
24+ bool enable_pdl);
2925
3026TVM_FFI_DLL_EXPORT_TYPED_FUNC (xqa_wrapper_mla, xqa_wrapper_mla);
3127
@@ -36,14 +32,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
3632#if LOW_PREC_OUTPUT
3733 TensorView rcpOutScale,
3834#endif
39- TensorView q, tvm::ffi::Optional<TensorView> attentionSinks,
40- #if PAGED_KV_CACHE_LAYOUT == 1
41- TensorView kCacheVLLM , TensorView vCacheVLLM,
42- #else
43- TensorView pool,
44- #endif
45- TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
46- int64_t batchSize, TensorView kvCacheScale,
35+ TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM ,
36+ TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
37+ TensorView seqLen, int64_t batchSize, TensorView kvCacheScale,
4738#if SPEC_DEC
4839 int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
4940#endif
0 commit comments