Commit 869c0c1

[wip] xqa nhd/hnd
upd
1 parent 5c6b9d9 commit 869c0c1

File tree

14 files changed (+210, -722 lines)

csrc/flashinfer_xqa_binding.cu

Lines changed: 7 additions & 16 deletions
@@ -18,14 +18,10 @@

 #if MLA_WRAPPER
 void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
-#if PAGED_KV_CACHE_LAYOUT == 1
-                     TensorView kCacheVLLM, TensorView vCacheVLLM,
-#else
-                     TensorView pool,
-#endif
-                     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
-                     int64_t batchSize, TensorView kvCacheScale, TensorView semaphores,
-                     TensorView scratch, bool enable_pdl);
+                     TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList,
+                     int64_t maxSeqLen, TensorView seqLen, int64_t batchSize,
+                     TensorView kvCacheScale, TensorView semaphores, TensorView scratch,
+                     bool enable_pdl);

 TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);

@@ -36,14 +32,9 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
 #if LOW_PREC_OUTPUT
                  TensorView rcpOutScale,
 #endif
-                 TensorView q, tvm::ffi::Optional<TensorView> attentionSinks,
-#if PAGED_KV_CACHE_LAYOUT == 1
-                 TensorView kCacheVLLM, TensorView vCacheVLLM,
-#else
-                 TensorView pool,
-#endif
-                 TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
-                 int64_t batchSize, TensorView kvCacheScale,
+                 TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
+                 TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
+                 TensorView seqLen, int64_t batchSize, TensorView kvCacheScale,
 #if SPEC_DEC
                  int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
 #endif
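
This change removes the PAGED_KV_CACHE_LAYOUT == 0 branch (a single pool tensor) from both bindings, so the separate kCacheVLLM/vCacheVLLM pools are now the only paged KV cache inputs. For orientation, a sketch of the resulting xqa_wrapper_mla prototype follows; the per-parameter comments are inferences from the names and from the layout note removed in defines.h, not statements taken from this commit.

void xqa_wrapper_mla(int64_t multiProcessorCount,  // SM count of the target GPU (assumed)
                     double qScale,                // query scaling factor (assumed)
                     TensorView output,            // attention output
                     TensorView q,                 // query tensor
                     TensorView kCacheVLLM,        // paged K cache pool, vLLM/SGLang layout
                     TensorView vCacheVLLM,        // paged V cache pool, vLLM/SGLang layout
                     TensorView kvCachePageList,   // page table mapping requests to cache pages (assumed)
                     int64_t maxSeqLen,            // maximum KV sequence length
                     TensorView seqLen,            // per-request KV sequence lengths (assumed)
                     int64_t batchSize,
                     TensorView kvCacheScale,      // dequantization scale for the KV cache (assumed)
                     TensorView semaphores,        // cross-CTA synchronization workspace (assumed)
                     TensorView scratch,           // temporary workspace buffer
                     bool enable_pdl);             // programmatic dependent launch toggle (assumed)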

csrc/xqa/defines.h

Lines changed: 1 addition & 17 deletions
@@ -92,21 +92,6 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
 #define TOKENS_PER_PAGE 32
 #endif

-// don't modify
-#ifndef USE_PAGED_KV_CACHE
-#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0)
-#endif
-
-// Paged KV Cache Format
-// 0 - XQA Original
-// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for
-// VLLM/SGLang
-#ifdef USE_PAGED_KV_CACHE
-#ifndef PAGED_KV_CACHE_LAYOUT
-#define PAGED_KV_CACHE_LAYOUT 0
-#endif
-#endif
-
 // don't modify
 #define USE_BEAM_SEARCH (BEAM_WIDTH > 1)

@@ -170,8 +155,7 @@ static_assert(CACHE_ELEM_ENUM != 0);
 #endif

 // true should be better if warpTile.x * cacheElemSize < 128. otherwise use false.
-#define GRP_LOAD_V \
-  (CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && USE_PAGED_KV_CACHE && BEAM_WIDTH > 1)
+#define GRP_LOAD_V (CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && BEAM_WIDTH > 1)

 // use custom barrier for NVRTC to avoid pulling in many headers
 #ifndef USE_CUSTOM_BARRIER
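
With the layout selector gone, GRP_LOAD_V in defines.h also drops its USE_PAGED_KV_CACHE term, and the separate K/V pools with (batch, seq_len, head, head_elem) ordering (an NHD-style page, per the comment removed above) become the layout the kernels target. The snippet below is a hypothetical address computation for that layout; every identifier in it is made up for illustration and none of them appear in this commit. An HND-style variant would simply swap the token and head strides.

// Hypothetical: flat offset of one element in a vLLM/SGLang-style paged K cache,
// assuming each page is laid out as (token_in_page, kv_head, head_elem), i.e. NHD per page.
// All names are illustrative only; they are not identifiers from this commit.
__host__ __device__ inline int64_t kCacheElemOffset(int64_t pageIdx, int32_t tokenInPage,
                                                    int32_t kvHead, int32_t headElem,
                                                    int32_t tokensPerPage, int32_t nbKHeads,
                                                    int32_t headElems) {
  return ((pageIdx * tokensPerPage + tokenInPage) * static_cast<int64_t>(nbKHeads) + kvHead) *
             headElems +
         headElem;
}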

0 commit comments
