Skip to content

Commit b94cb61

Browse files
committed
use scalar for kv_scale
Signed-off-by: Qidi Sang <[email protected]>
1 parent 2d68a6b commit b94cb61

File tree

9 files changed

+119
-141
lines changed

9 files changed

+119
-141
lines changed

csrc/flashinfer_xqa_binding.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
#if MLA_WRAPPER
2020
void xqa_wrapper_mla(int64_t multiProcessorCount, double qScale, TensorView output, TensorView q,
2121
TensorView kCacheVLLM, TensorView vCacheVLLM, TensorView kvCachePageList,
22-
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize,
23-
TensorView kvCacheScale, TensorView semaphores, TensorView scratch,
24-
bool enable_pdl);
22+
int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, double kvCacheScale,
23+
TensorView semaphores, TensorView scratch, bool enable_pdl);
2524

2625
TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper_mla, xqa_wrapper_mla);
2726

@@ -34,7 +33,7 @@ void xqa_wrapper(bool run_sm90_fp8_mha, int64_t multiProcessorCount, int64_t nbK
3433
#endif
3534
TensorView q, tvm::ffi::Optional<TensorView> attentionSinks, TensorView kCacheVLLM,
3635
TensorView vCacheVLLM, TensorView kvCachePageList, int64_t maxSeqLen,
37-
TensorView seqLen, int64_t batchSize, TensorView kvCacheScale,
36+
TensorView seqLen, int64_t batchSize, double kvCacheScale,
3837
#if SPEC_DEC
3938
int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask,
4039
#endif

csrc/xqa/mha.cu

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,8 +1301,8 @@ CUBIN_EXPORT __global__
13011301
#endif
13021302
#endif
13031303
uint32_t const batchSize,
1304-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V
1305-
// cache. Used only for int8/fp8 KV cache.
1304+
float kvCacheScale, // Passed by value. Same scale for K and V
1305+
// cache. Used only for int8/fp8 KV cache.
13061306
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
13071307
uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
13081308
assert(allowMultiBlockMode || gridDim.x == 1);
@@ -1503,7 +1503,7 @@ CUBIN_EXPORT __global__
15031503
};
15041504
if (warpIdx.z == 0) {
15051505
float const qkScale =
1506-
qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) *
1506+
qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) *
15071507
rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax.
15081508
CircIdx<nbKBuffers> idxCurrSMemKBuf{nbKBuffers - 1};
15091509
auto const getSMemKTile = [&](uint32_t idx) -> SharedMem::KSmemBuffer& {
@@ -2156,7 +2156,7 @@ CUBIN_EXPORT __global__
21562156
}
21572157
}
21582158

2159-
float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F);
2159+
float voScale = (isKVCacheQuantized ? kvCacheScale : 1.F);
21602160
if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN.
21612161
// The attention sinks are moved to the multi-block reduction part if the multi-block is
21622162
// enabled.
@@ -2410,8 +2410,8 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
24102410
BeamSearchParams const beamSearchParams,
24112411
#endif
24122412
uint32_t const batchSize,
2413-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
2414-
// Used only for int8/fp8 KV cache.
2413+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
2414+
// Used only for int8/fp8 KV cache.
24152415
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
24162416
uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
24172417
#if SPEC_DEC
@@ -2442,40 +2442,39 @@ static constexpr auto kernel_mha = kernel_mha_impl;
24422442
#endif
24432443

24442444
#ifndef GENERATE_CUBIN
2445-
void launchMHA(
2446-
cudaDeviceProp const& prop, uint32_t nbKHeads,
2445+
void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
24472446
#if SLIDING_WINDOW
2448-
uint32_t slidingWinSize,
2447+
uint32_t slidingWinSize,
24492448
#endif
2450-
float qScale, OutputHead* output,
2449+
float qScale, OutputHead* output,
24512450
#if LOW_PREC_OUTPUT
2452-
float const* rcpOutScale,
2451+
float const* rcpOutScale,
24532452
#endif
24542453
#if USE_INPUT_KV
2455-
InputHead const* qkv,
2454+
InputHead const* qkv,
24562455
#if ROPE_STYLE != 0
2457-
Vec<float, validElemsPerHead> const* ropeCosSin,
2456+
Vec<float, validElemsPerHead> const* ropeCosSin,
24582457
#endif
24592458
#else
2460-
InputHead const* q,
2461-
#endif
2462-
float const* attentionSinks, // [headGrpSize]
2463-
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
2464-
KVCachePageIndex const*
2465-
kvCachePageList, // device pointer. shape:
2466-
// KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
2467-
uint32_t maxSeqLen, uint32_t const* seqLen,
2459+
InputHead const* q,
2460+
#endif
2461+
float const* attentionSinks, // [headGrpSize]
2462+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
2463+
KVCachePageIndex const*
2464+
kvCachePageList, // device pointer. shape:
2465+
// KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
2466+
uint32_t maxSeqLen, uint32_t const* seqLen,
24682467
#if BEAM_WIDTH > 1
2469-
BeamSearchParams const& beamSearchParams,
2468+
BeamSearchParams const& beamSearchParams,
24702469
#endif
2471-
uint32_t batchSize,
2472-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
2473-
// Used only for int8/fp8 KV cache.
2470+
uint32_t batchSize,
2471+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
2472+
// Used only for int8/fp8 KV cache.
24742473
#if SPEC_DEC
2475-
SpecDecParams const& specDecParams,
2474+
SpecDecParams const& specDecParams,
24762475
#endif
2477-
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
2478-
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
2476+
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
2477+
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
24792478
#if SPEC_DEC
24802479
auto const qSeqLen = specDecParams.qSeqLen;
24812480
auto const qCuSeqLens = specDecParams.qCuSeqLens;
@@ -2571,7 +2570,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
25712570
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
25722571
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,
25732572
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
2574-
float const* __restrict__ kvCacheScale,
2573+
float kvCacheScale,
25752574
#if SPEC_DEC
25762575
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
25772576
#endif

csrc/xqa/mha.h

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -88,40 +88,39 @@ struct BeamSearchParams {
8888
// but we have to match trt-llm API.
8989
};
9090

91-
void launchMHA(
92-
cudaDeviceProp const& prop, uint32_t const nbKHeads,
91+
void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
9392
#if SLIDING_WINDOW
94-
uint32_t slidingWinSize,
93+
uint32_t slidingWinSize,
9594
#endif
96-
float qScale, OutputHead* output,
95+
float qScale, OutputHead* output,
9796
#if LOW_PREC_OUTPUT
98-
float const* rcpOutScale,
97+
float const* rcpOutScale,
9998
#endif
10099
#if USE_INPUT_KV
101-
InputHead const* qkv,
100+
InputHead const* qkv,
102101
#if ROPE_STYLE != 0
103-
Vec<float, validElemsPerHead> const* ropeCosSin,
102+
Vec<float, validElemsPerHead> const* ropeCosSin,
104103
#endif
105104
#else
106-
InputHead const* q,
107-
#endif
108-
float const* attentionSinks, // [headGrpSize]
109-
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
110-
KVCachePageIndex const*
111-
kvCachePageList, // device pointer. shape:
112-
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
113-
uint32_t maxSeqLen, uint32_t const* seqLen,
105+
InputHead const* q,
106+
#endif
107+
float const* attentionSinks, // [headGrpSize]
108+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
109+
KVCachePageIndex const*
110+
kvCachePageList, // device pointer. shape:
111+
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
112+
uint32_t maxSeqLen, uint32_t const* seqLen,
114113
#if BEAM_WIDTH > 1
115-
BeamSearchParams const& beamSearchParams,
114+
BeamSearchParams const& beamSearchParams,
116115
#endif
117-
uint32_t batchSize,
118-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
119-
// Used only for int8/fp8 KV cache.
116+
uint32_t batchSize,
117+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
118+
// Used only for int8/fp8 KV cache.
120119
#if SPEC_DEC
121-
SpecDecParams const& specDecParams,
120+
SpecDecParams const& specDecParams,
122121
#endif
123-
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
124-
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
122+
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
123+
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
125124

126125
void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
127126
float qScale, OutputHead* output,
@@ -131,7 +130,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
131130
InputHead const* q, float const* attentionSinks, GMemCacheHead* kCacheVLLM,
132131
GMemCacheHead* vCacheVLLM, KVCachePageIndex const* kvCachePageList,
133132
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
134-
float const* __restrict__ kvCacheScale,
133+
float kvCacheScale,
135134
#if SPEC_DEC
136135
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
137136
#endif
@@ -166,8 +165,8 @@ void launchHopperF8MHA(
166165
BeamSearchParams const& beamSearchParams,
167166
#endif
168167
uint32_t batchSize,
169-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
170-
// Used only for int8/fp8 KV cache.
168+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
169+
// Used only for int8/fp8 KV cache.
171170
#if SPEC_DEC
172171
SpecDecParams const& specDecParams,
173172
#endif
@@ -181,28 +180,26 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
181180
InputHead const* q, float const* attentionSinks,
182181
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
183182
KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
184-
uint32_t const* seqLen, uint32_t batchSize,
185-
float const* __restrict__ kvCacheScale,
183+
uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
186184
#if SPEC_DEC
187185
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
188186
#endif
189187
uint32_t* semaphores, void* scratch, bool enable_pdl,
190188
uint64_t kv_stride_page, uint64_t kv_stride_token,
191189
uint64_t kv_stride_head, cudaStream_t stream);
192190

193-
void launchMLA(
194-
cudaDeviceProp const& prop,
195-
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
196-
float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM,
197-
GMemCacheHead* vCacheVLLM,
198-
KVCachePageIndex const*
199-
kvCachePageList, // device pointer. shape:
200-
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
201-
// [batchSize][maxNbPagesPerSeq] (Layout 1)
202-
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
203-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
204-
// Used only for int8/fp8 KV cache.
205-
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);
191+
void launchMLA(cudaDeviceProp const& prop,
192+
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
193+
float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM,
194+
GMemCacheHead* vCacheVLLM,
195+
KVCachePageIndex const*
196+
kvCachePageList, // device pointer. shape:
197+
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
198+
// (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1)
199+
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
200+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
201+
// Used only for int8/fp8 KV cache.
202+
uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);
206203

207204
void launchMLAFlashInfer(
208205
uint32_t multiProcessorCount,
@@ -214,8 +211,8 @@ void launchMLAFlashInfer(
214211
// KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
215212
// [batchSize][maxNbPagesPerSeq] (Layout 1)
216213
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
217-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
218-
// Used only for int8/fp8 KV cache.
214+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
215+
// Used only for int8/fp8 KV cache.
219216
uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
220217
uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
221218

csrc/xqa/mha_sm90.cu

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,8 @@ __launch_bounds__(128 * 3)
626626
BeamSearchParams const beamSearchParams,
627627
#endif
628628
uint32_t const batchSize,
629-
float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and
630-
// V cache. Used only for int8/fp8 KV cache.
629+
float kvCacheScale, // Passed by value. Same scale for K and
630+
// V cache. Used only for int8/fp8 KV cache.
631631
__grid_constant__ CUtensorMap const tensorMapVLLMK,
632632
__grid_constant__ CUtensorMap const tensorMapVLLMV,
633633
#if SPEC_DEC
@@ -773,7 +773,7 @@ __launch_bounds__(128 * 3)
773773
}
774774

775775
float const qkScale =
776-
qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) *
776+
qScale * (isKVCacheQuantized ? kvCacheScale : 1.f) *
777777
rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax.
778778
uint32_t const warpRank = warpIdx.x;
779779

@@ -962,7 +962,7 @@ __launch_bounds__(128 * 3)
962962
#else
963963
constexpr float oScale = 1.F;
964964
#endif
965-
float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale;
965+
float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale : 1.f) * oScale;
966966

967967
Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction.
968968
gmma::fence();
@@ -1316,7 +1316,7 @@ __launch_bounds__(128 * 3)
13161316
headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq;
13171317
IOHead const& inKHead = qkv[inputKHeadOffset];
13181318
uint32_t const lane = laneId();
1319-
float const rcpKScale = 1.F / kvCacheScale[0];
1319+
float const rcpKScale = 1.F / kvCacheScale;
13201320
#if ROPE_STYLE == 0
13211321
constexpr bool isNeox = false;
13221322
auto const pairs =
@@ -1375,7 +1375,7 @@ __launch_bounds__(128 * 3)
13751375
(headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq;
13761376
IOHead const& inVHead = qkv[inputVHeadOffset];
13771377
uint32_t const lane = laneId();
1378-
float const rcpVScale = 1.F / kvCacheScale[0];
1378+
float const rcpVScale = 1.F / kvCacheScale;
13791379
constexpr bool isNeox = false;
13801380
auto const pairs =
13811381
loadHead<InputElem, isNeox, warp_size, float>(inVHead, lane) * rcpVScale;
@@ -2931,8 +2931,8 @@ void launchHopperF8MHA(
29312931
BeamSearchParams const& beamSearchParams,
29322932
#endif
29332933
uint32_t batchSize,
2934-
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
2935-
// Used only for int8/fp8 KV cache.
2934+
float kvCacheScale, // Passed by value. Same scale for K and V cache.
2935+
// Used only for int8/fp8 KV cache.
29362936
#if SPEC_DEC
29372937
SpecDecParams const& specDecParams,
29382938
#endif
@@ -3044,8 +3044,7 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
30443044
InputHead const* q, float const* attentionSinks,
30453045
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
30463046
KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen,
3047-
uint32_t const* seqLen, uint32_t batchSize,
3048-
float const* __restrict__ kvCacheScale,
3047+
uint32_t const* seqLen, uint32_t batchSize, float kvCacheScale,
30493048
#if SPEC_DEC
30503049
uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
30513050
#endif

0 commit comments

Comments
 (0)