Skip to content

Commit 6d19a75

Browse files
Authored: use scalar for kv_scale in xqa (#2033)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Breaking Changes** * Public xqa/xqa_mla entry points now accept kv_scale as a plain float (default 1.0) instead of a 1-element tensor. Update call sites accordingly. * **Documentation** * Docstrings updated to reflect kv_scale as float. * **Tests** * Tests updated to pass scalar kv_scale, with added parameterization and conditional skip for FP8 kv-cache scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Qidi Sang <[email protected]>
1 parent 579012b commit 6d19a75

File tree

9 files changed

+60
-82
lines changed

9 files changed

+60
-82
lines changed

csrc/flashinfer_xqa_binding.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
#if MLA_WRAPPER
2020
void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
2121
TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList,
22-
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize,
23-
TensorView kvCacheScale, TensorView semaphores, TensorView scratch,
24-
bool enable_pdl);
22+
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale,
23+
TensorView semaphores, TensorView scratch, bool enable_pdl);
2524

2625
TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);
2726

@@ -34,7 +33,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
3433
#endif
3534
TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
3635
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
37-
TensorView seqLen, int64_t batchSize, TensorView kvCacheScale,
36+
TensorView seqLen, int64_t batchSize, double kvCacheScale,
3837
#if SPEC_DEC
3938
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
4039
#endif

csrc/xqa/mha.cu

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,8 +1301,7 @@ CUBIN_EXPORT __global__
13011301
#endif
13021302
#endif
13031303
uint32_t const batchSize,
1304-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V
1305-
// cache. Used only for int8/fp8 KV cache.
1304+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
13061305
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
13071306
uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
13081307
assert(allowMultiBlockMode || gridDim.x == 1);
@@ -1503,7 +1502,7 @@ CUBIN_EXPORT __global__
15031502
};
15041503
if (warpIdx.z == 0) {
15051504
float const qkScale =
1506-
qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) *
1505+
qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) *
15071506
rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax.
15081507
CircIdx<nbKBuffers> idxCurrSMemKBuf{nbKBuffers - 1};
15091508
auto const getSMemKTile = [&](uint32_t idx) -> SharedMem::KSmemBuffer& {
@@ -2156,7 +2155,7 @@ CUBIN_EXPORT __global__
21562155
}
21572156
}
21582157

2159-
float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F);
2158+
float voScale = (isKVCacheQuantized ? kvCacheScale : 1.F);
21602159
if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN.
21612160
// The attention sinks are moved to the multi-block reduction part if the multi-block is
21622161
// enabled.
@@ -2410,8 +2409,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
24102409
BeamSearchParams const beamSearchParams,
24112410
#endif
24122411
uint32_t const batchSize,
2413-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
2414-
// Used only for int8/fp8 KV cache.
2412+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
24152413
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
24162414
uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
24172415
#if SPEC_DEC
@@ -2469,8 +2467,7 @@ void launchMHA(
24692467
BeamSearchParams const& beamSearchParams,
24702468
#endif
24712469
uint32_t batchSize,
2472-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
2473-
// Used only for int8/fp8 KV cache.
2470+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
24742471
#if SPEC_DEC
24752472
SpecDecParams const& specDecParams,
24762473
#endif
@@ -2571,7 +2568,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
25712568
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
25722569
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,
25732570
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
2574-
float const* __restrict__ kvCacheScale,
2571+
float kvCacheScale,
25752572
#if SPEC_DEC
25762573
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
25772574
#endif

csrc/xqa/mha.h

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ void launchMHA(
115115
BeamSearchParams const& beamSearchParams,
116116
#endif
117117
uint32_t batchSize,
118-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
119-
// Used only for int8/fp8 KV cache.
118+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
120119
#if SPEC_DEC
121120
SpecDecParams const& specDecParams,
122121
#endif
@@ -131,7 +130,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
131130
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
132131
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,
133132
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
134-
float const* __restrict__ kvCacheScale,
133+
float kvCacheScale,
135134
#if SPEC_DEC
136135
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
137136
#endif
@@ -166,8 +165,7 @@ void launchHopperF8MHA(
166165
BeamSearchParams const& beamSearchParams,
167166
#endif
168167
uint32_t batchSize,
169-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
170-
// Used only for int8/fp8 KV cache.
168+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
171169
#if SPEC_DEC
172170
SpecDecParams const& specDecParams,
173171
#endif
@@ -181,8 +179,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
181179
InputHead const* q, float const* attentionSinks,
182180
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
183181
KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
184-
uint32_t const* seqLen, uint32_t batchSize,
185-
float const* __restrict__ kvCacheScale,
182+
uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
186183
#if SPEC_DEC
187184
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
188185
#endif
@@ -197,11 +194,10 @@ void launchMLA(
197194
GMemCacheHead* vCacheVLLM,
198195
KVCachePageIndex const*
199196
kvCachePageList, // device pointer. shape:
200-
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
201-
// [batchSize][maxNbPagesPerSeq] (Layout 1)
197+
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
198+
// (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1)
202199
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
203-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
204-
// Used only for int8/fp8 KV cache.
200+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
205201
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);
206202

207203
void launchMLAFlashInfer(
@@ -214,8 +210,7 @@ void launchMLAFlashInfer(
214210
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
215211
// [batchSize][maxNbPagesPerSeq] (Layout 1)
216212
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
217-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
218-
// Used only for int8/fp8 KV cache.
213+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
219214
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
220215
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
221216

csrc/xqa/mha_sm90.cu

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,7 @@ __launch_bounds__(128 * 3)
626626
BeamSearchParams const beamSearchParams,
627627
#endif
628628
uint32_t const batchSize,
629-
float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and
630-
// V cache. Used only for int8/fp8 KV cache.
629+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
631630
__grid_constant__ CUtensorMap const tensorMapVLLMK,
632631
__grid_constant__ CUtensorMap const tensorMapVLLMV,
633632
#if SPEC_DEC
@@ -773,7 +772,7 @@ __launch_bounds__(128 * 3)
773772
}
774773

775774
float const qkScale =
776-
qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) *
775+
qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) *
777776
rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax.
778777
uint32_t const warpRank = warpIdx.x;
779778

@@ -962,7 +961,7 @@ __launch_bounds__(128 * 3)
962961
#else
963962
constexpr float oScale = 1.F;
964963
#endif
965-
float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale;
964+
float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * oScale;
966965

967966
Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction.
968967
gmma::fence();
@@ -1316,7 +1315,7 @@ __launch_bounds__(128 * 3)
13161315
headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq;
13171316
IOHead const& inKHead = qkv[inputKHeadOffset];
13181317
uint32_t const lane = laneId();
1319-
float const rcpKScale = 1.F / kvCacheScale[0];
1318+
float const rcpKScale = 1.F / kvCacheScale;
13201319
#if ROPE_STYLE == 0
13211320
constexpr bool isNeox = false;
13221321
auto const pairs =
@@ -1375,7 +1374,7 @@ __launch_bounds__(128 * 3)
13751374
(headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq;
13761375
IOHead const& inVHead = qkv[inputVHeadOffset];
13771376
uint32_t const lane = laneId();
1378-
float const rcpVScale = 1.F / kvCacheScale[0];
1377+
float const rcpVScale = 1.F / kvCacheScale;
13791378
constexpr bool isNeox = false;
13801379
auto const pairs =
13811380
loadHead<InputElem, isNeox, warp_size, float>(inVHead, lane) * rcpVScale;
@@ -2931,8 +2930,7 @@ void launchHopperF8MHA(
29312930
BeamSearchParams const& beamSearchParams,
29322931
#endif
29332932
uint32_t batchSize,
2934-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
2935-
// Used only for int8/fp8 KV cache.
2933+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
29362934
#if SPEC_DEC
29372935
SpecDecParams const& specDecParams,
29382936
#endif
@@ -3044,8 +3042,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
30443042
InputHead const* q, float const* attentionSinks,
30453043
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
30463044
KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
3047-
uint32_t const* seqLen, uint32_t batchSize,
3048-
float const* __restrict__ kvCacheScale,
3045+
uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
30493046
#if SPEC_DEC
30503047
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
30513048
#endif

csrc/xqa/mla_sm120.cu

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,7 @@ struct KernelArgs {
395395
OutputHead* __restrict__ const& output; // [totalNbIntputTokens][nbQHeads]
396396
KVCacheList<usePagedKVCache> const& cacheList;
397397
uint32_t const& batchSize;
398-
float const* __restrict__ const& kvCacheScale; // Device memory scalar. Same scale for K and V
399-
// cache. Used only for int8/fp8 KV cache.
398+
float kvCacheScale; // Same scale for K and V cache. Used only for int8/fp8 KV cache.
400399
Vec<CgaXBuffer, nbProducerCtasPerCga>* __restrict__ const&
401400
cgaXBuf; // [totalNbInputTokens][maxNbSubSeq]
402401
uint32_t* __restrict__ const& semaphores; // [totalNbInputTokens]
@@ -449,7 +448,7 @@ struct Producer {
449448
__syncthreads();
450449
#endif
451450
if (threadIdx.x == 0) {
452-
smem.qkScaleLog2e = args.qScale * args.kvCacheScale[0] * log2e;
451+
smem.qkScaleLog2e = args.qScale * args.kvCacheScale * log2e;
453452
}
454453

455454
if (threadIdx.x < headGrpSize) {
@@ -1228,7 +1227,7 @@ __device__ inline void Consumer::compute() {
12281227

12291228
ThrdRegRowMax const accRowSum =
12301229
loadShmRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], tileBase.y, lane);
1231-
float const xvScale = computeRowSumFromF8 ? args.kvCacheScale[0] : args.kvCacheScale[0] * xScale;
1230+
float const xvScale = computeRowSumFromF8 ? args.kvCacheScale : args.kvCacheScale * xScale;
12321231
WarpOutputTile const output = finalize(acc, accRowSum, xvScale, lane);
12331232

12341233
bool const isMultiBlockMode = (nbSubSeq != 1);
@@ -1553,8 +1552,7 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha
15531552
float const qScale,
15541553
OutputHead* __restrict__ const output, // [totalNbIntputTokens][nbQHeads]
15551554
KVCacheList<usePagedKVCache> const cacheList, uint32_t const batchSize,
1556-
float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V
1557-
// cache. Used only for int8/fp8 KV cache.
1555+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
15581556
Vec<CgaXBuffer,
15591557
nbProducerCtasPerCga>* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq]
15601558
uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens]
@@ -1657,8 +1655,7 @@ void launchMLA(
16571655
KVCachePageIndex const* kvCachePageList, // device pointer. shape:
16581656
// [batchSize][maxNbPagesPerSeq] (Layout 1)
16591657
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
1660-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
1661-
// Used only for int8/fp8 KV cache.
1658+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
16621659
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
16631660
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
16641661
#if IS_MLA
@@ -1779,8 +1776,7 @@ void launchMLAFlashInfer(
17791776
KVCachePageIndex const* kvCachePageList, // device pointer. shape:
17801777
// [batchSize][maxNbPagesPerSeq] (Layout 1)
17811778
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
1782-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
1783-
// Used only for int8/fp8 KV cache.
1779+
float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
17841780
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
17851781
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
17861782
#if IS_MLA

csrc/xqa/xqa_wrapper.cu

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ using tvm::ffi::Optional;
2222
#if MLA_WRAPPER
2323
void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
2424
TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList,
25-
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize,
26-
TensorView kvCacheScale, TensorView semaphores, TensorView scratch,
27-
bool enable_pdl) {
25+
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale,
26+
TensorView semaphores, TensorView scratch, bool enable_pdl) {
2827
auto stream = get_stream(output.device());
2928

3029
// Extract strides from TensorView (in elements, not bytes)
@@ -39,8 +38,7 @@ void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView outp
3938
reinterpret_cast<GMemCacheHead*>(vCacheVLLM.data_ptr()),
4039
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
4140
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
42-
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
43-
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
41+
kvCacheScale, reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
4442
reinterpret_cast<void*>(scratch.data_ptr()), enable_pdl, kv_stride_page,
4543
kv_stride_token, kv_stride_head, stream);
4644
}
@@ -53,7 +51,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
5351
#endif
5452
TensorView q, Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
5553
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
56-
TensorView seqLen, int64_t batchSize, TensorView kvCacheScale,
54+
TensorView seqLen, int64_t batchSize, double kvCacheScale,
5755
#if SPEC_DEC
5856
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
5957
#endif
@@ -78,8 +76,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
7876
reinterpret_cast<GMemCacheHead*>(kCacheVLLM.data_ptr()),
7977
reinterpret_cast<GMemCacheHead*>(vCacheVLLM.data_ptr()),
8078
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
81-
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
82-
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
79+
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize, kvCacheScale,
8380
#if SPEC_DEC
8481
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
8582
reinterpret_cast<MaskType const*>(mask.data_ptr()),

flashinfer/decode.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2461,9 +2461,7 @@ def xqa_batch_decode_with_kv_cache(
24612461
page_size,
24622462
sinks=sinks_new,
24632463
q_scale=q_scale_value,
2464-
kv_scale=torch.tensor(
2465-
[kv_scale_value], dtype=torch.float32, device=query.device
2466-
),
2464+
kv_scale=kv_scale_value,
24672465
sliding_win_size=window_left + 1 if window_left >= 0 else 0,
24682466
kv_layout=kv_layout,
24692467
sm_count=sm_count,

0 commit comments

Comments (0)