
Commit 67bdec3

use vLLM KV layout, but some unit tests fail
Signed-off-by: Qidi Sang <[email protected]>
1 parent 82957fc

8 files changed: +149 −82 lines

csrc/flashinfer_xqa_binding.cu

Lines changed: 6 additions & 1 deletion

@@ -21,7 +21,12 @@ void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads
 #if LOW_PREC_OUTPUT
     TensorView rcpOutScale,
 #endif
-    TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView pool,
+    TensorView q, tvm::ffi::Optional<TensorView> attentionSinks,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    TensorView kCacheVLLM, TensorView vCacheVLLM,
+#else
+    TensorView pool,
+#endif
     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
     int64_t batchSize, TensorView kvCacheScale,
 #if SPEC_DEC

csrc/xqa/mha.cu

Lines changed: 11 additions & 1 deletion

@@ -2659,7 +2659,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
 #if LOW_PREC_OUTPUT
     float const* rcpOutScale,
 #endif
-    InputHead const* q, float const* attentionSinks, GMemCacheHead* pool,
+    InputHead const* q, float const* attentionSinks,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
+    GMemCacheHead* pool,
+#endif
     KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
     uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,
@@ -2691,7 +2696,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
   auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
 #if USE_PAGED_KV_CACHE
   uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
+#if PAGED_KV_CACHE_LAYOUT == 1
+  KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen,
+                                    maxNbPagesPerSeq};
+#else
   KVCacheList<true> const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
+#endif
   cudaLaunchKernelEx(&launchCfg, kernel_mha,
 #if SPEC_DEC
       qSeqLen, nbKHeads, headGrpSize, qCuSeqLens,
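
With layout 1, the cache list carries two base pointers (one per K and V pool) in place of the single pooled pointer. The sketch below shows how such a split list could resolve the K head for one token. It is an illustration only: the struct fields, the helper name kHeadAt, and the [page][token][head] ordering inside each pool are assumptions, not the actual KVCacheList definition from this commit.

// Illustration only (not from this commit). Assumes GMemCacheHead and
// KVCachePageIndex from csrc/xqa/mha.h, and a [page][token][head] ordering
// inside each pool.
#include <cstddef>
#include <cstdint>
#include "mha.h"

struct SplitKVCacheList {
  GMemCacheHead* kCacheVLLM;                // base of the K pool
  GMemCacheHead* vCacheVLLM;                // base of the V pool
  KVCachePageIndex const* kvCachePageList;  // [batchSize][maxNbPagesPerSeq]
  uint32_t const* seqLen;
  uint32_t maxNbPagesPerSeq;
};

// Resolve the K head for (batch b, token t, head h): look up the page that
// holds token t, then offset by the token's slot within that page.
__host__ __device__ inline GMemCacheHead* kHeadAt(
    SplitKVCacheList const& list, uint32_t b, uint32_t t, uint32_t h,
    uint32_t tokensPerPage, uint32_t nbKHeads) {
  KVCachePageIndex const page =
      list.kvCachePageList[b * list.maxNbPagesPerSeq + t / tokensPerPage];
  size_t const tokInPage = t % tokensPerPage;
  return list.kCacheVLLM +
         (static_cast<size_t>(page) * tokensPerPage + tokInPage) * nbKHeads + h;
}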

csrc/xqa/mha.h

Lines changed: 13 additions & 3 deletions

@@ -135,7 +135,12 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
 #if LOW_PREC_OUTPUT
     float const* rcpOutScale,
 #endif
-    InputHead const* q, float const* attentionSinks, GMemCacheHead* pool,
+    InputHead const* q, float const* attentionSinks,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
+    GMemCacheHead* pool,
+#endif
     KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
     uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,
@@ -192,8 +197,13 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
     float const* rcpOutScale,
 #endif
     InputHead const* q, float const* attentionSinks,
-    GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList,
-    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
+    GMemCacheHead* pool,
+#endif
+    KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
+    uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,

csrc/xqa/mha_sm90.cu

Lines changed: 7 additions & 2 deletions

@@ -3171,8 +3171,13 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
     float const* rcpOutScale,
 #endif
     InputHead const* q, float const* attentionSinks,
-    GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList,
-    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
+    GMemCacheHead* pool,
+#endif
+    KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
+    uint32_t const* seqLen, uint32_t batchSize,
     float const* __restrict__ kvCacheScale,
 #if SPEC_DEC
     uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,

csrc/xqa/xqa_wrapper.cu

Lines changed: 11 additions & 1 deletion

@@ -24,7 +24,12 @@ void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads
 #if LOW_PREC_OUTPUT
     TensorView rcpOutScale,
 #endif
-    TensorView q, Optional<TensorView> attentionSinks, TensorView pool,
+    TensorView q, Optional<TensorView> attentionSinks,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    TensorView kCacheVLLM, TensorView vCacheVLLM,
+#else
+    TensorView pool,
+#endif
     TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen,
     int64_t batchSize, TensorView kvCacheScale,
 #if SPEC_DEC
@@ -43,7 +48,12 @@ void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads
     reinterpret_cast<float const*>(rcpOutScale->data),
 #endif
     reinterpret_cast<InputHead const*>(q->data), attentionSinksPtr,
+#if PAGED_KV_CACHE_LAYOUT == 1
+    reinterpret_cast<GMemCacheHead*>(kCacheVLLM->data),
+    reinterpret_cast<GMemCacheHead*>(vCacheVLLM->data),
+#else
     reinterpret_cast<GMemCacheHead*>(pool->data),
+#endif
     reinterpret_cast<KVCachePageIndex const*>(kvCachePageList->data), maxSeqLen,
     reinterpret_cast<uint32_t const*>(seqLen->data), batchSize,
     reinterpret_cast<float const*>(kvCacheScale->data),
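
The wrapper reinterprets raw TensorView data pointers without shape checks, so the host caller is responsible for handing over caches in the expected geometry. As a reference point, a plausible vLLM-style geometry is sketched below; the dimension order is an assumption for illustration and is not specified by this commit.

// Assumed geometry (illustration only; dim order NOT confirmed by this commit):
//   kCacheVLLM, vCacheVLLM : [numPages, tokensPerPage, nbKHeads, headDim]
//   kvCachePageList        : [batchSize, maxNbPagesPerSeq] page indices
// Each GMemCacheHead is one (token, head) vector, so one page index advances
// both pools by the same number of heads:
#include <cstddef>
constexpr size_t headsPerPage(size_t tokensPerPage, size_t nbKHeads) {
  return tokensPerPage * nbKHeads;
}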

flashinfer/jit/xqa.py

Lines changed: 2 additions & 0 deletions

@@ -25,6 +25,8 @@
 
 xqa_nvcc_flags = [
     "-DNDEBUG=1",
+    "-DUSE_PAGED_KV_CACHE=1",
+    "-DPAGED_KV_CACHE_LAYOUT=1",
     "-DBEAM_WIDTH=1",
     "-DUSE_INPUT_KV=0",
     "-DUSE_CUSTOM_BARRIER=1",

flashinfer/xqa.py

Lines changed: 10 additions & 5 deletions

@@ -60,7 +60,8 @@ def xqa(
     output: torch.Tensor,
     q: torch.Tensor,
     attentionSinks: Optional[torch.Tensor],
-    pool: torch.Tensor,
+    kCacheVLLM: torch.Tensor,
+    vCacheVLLM: torch.Tensor,
     kvCachePageList: torch.Tensor,
     maxSeqLen: int,
     seqLen: torch.Tensor,
@@ -78,7 +79,8 @@ def xqa(
         output,
         q,
         attentionSinks,
-        pool,
+        kCacheVLLM,
+        vCacheVLLM,
         kvCachePageList,
         maxSeqLen,
         seqLen,
@@ -100,7 +102,8 @@ def _fake_xqa(
     output: torch.Tensor,
     q: torch.Tensor,
     attentionSinks: Optional[torch.Tensor],
-    pool: torch.Tensor,
+    kCacheVLLM: torch.Tensor,
+    vCacheVLLM: torch.Tensor,
     kvCachePageList: torch.Tensor,
     maxSeqLen: int,
     seqLen: torch.Tensor,
@@ -131,7 +134,8 @@ def xqa(
     output: torch.Tensor,
     q: torch.Tensor,
     attentionSinks: Optional[torch.Tensor],
-    pool: torch.Tensor,
+    kCacheVLLM: torch.Tensor,
+    vCacheVLLM: torch.Tensor,
     kvCachePageList: torch.Tensor,
     maxSeqLen: int,
     seqLen: torch.Tensor,
@@ -161,7 +165,8 @@ def xqa(
         output,
         q,
         attentionSinks,
-        pool,
+        kCacheVLLM,
+        vCacheVLLM,
         kvCachePageList,
         maxSeqLen,
         seqLen,
