Commit 33db49e

add kv scale in test parameter
Signed-off-by: Qidi Sang <[email protected]>
1 parent b94cb61 commit 33db49e

File tree

5 files changed: +87 -87 lines

csrc/xqa/mha.cu
csrc/xqa/mha.h
csrc/xqa/mha_sm90.cu
csrc/xqa/mla_sm120.cu
tests/attention/test_xqa.py

csrc/xqa/mha.cu

Lines changed: 23 additions & 25 deletions

@@ -1301,8 +1301,7 @@ CUBIN_EXPORT __global__
 #endif
 #endif
     uint32_t const batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V
-                        // cache. Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
     uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
     uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
   assert(allowMultiBlockMode || gridDim.x == 1);

@@ -2410,8 +2409,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha(
     BeamSearchParams const beamSearchParams,
 #endif
     uint32_t const batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                        // Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
     uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
     uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
 #if SPEC_DEC

@@ -2442,39 +2440,39 @@ static constexpr auto kernel_mha = kernel_mha_impl;
 #endif

 #ifndef GENERATE_CUBIN
-void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
+void launchMHA(
+    cudaDeviceProp const& prop, uint32_t nbKHeads,
 #if SLIDING_WINDOW
-               uint32_t slidingWinSize,
+    uint32_t slidingWinSize,
 #endif
-               float qScale, OutputHead* output,
+    float qScale, OutputHead* output,
 #if LOW_PREC_OUTPUT
-               float const* rcpOutScale,
+    float const* rcpOutScale,
 #endif
 #if USE_INPUT_KV
-               InputHead const* qkv,
+    InputHead const* qkv,
 #if ROPE_STYLE != 0
-               Vec<float, validElemsPerHead> const* ropeCosSin,
+    Vec<float, validElemsPerHead> const* ropeCosSin,
 #endif
 #else
-               InputHead const* q,
-#endif
-               float const* attentionSinks, // [headGrpSize]
-               GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
-               KVCachePageIndex const*
-                   kvCachePageList, // device pointer. shape:
-                                    // KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
-               uint32_t maxSeqLen, uint32_t const* seqLen,
+    InputHead const* q,
+#endif
+    float const* attentionSinks, // [headGrpSize]
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+    KVCachePageIndex const*
+        kvCachePageList, // device pointer. shape:
+                         // KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
+    uint32_t maxSeqLen, uint32_t const* seqLen,
 #if BEAM_WIDTH > 1
-               BeamSearchParams const& beamSearchParams,
+    BeamSearchParams const& beamSearchParams,
 #endif
-               uint32_t batchSize,
-               float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                                   // Used only for int8/fp8 KV cache.
+    uint32_t batchSize,
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
 #if SPEC_DEC
-               SpecDecParams const& specDecParams,
+    SpecDecParams const& specDecParams,
 #endif
-               uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
-               uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
+    uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
+    uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
 #if SPEC_DEC
   auto const qSeqLen = specDecParams.qSeqLen;
   auto const qCuSeqLens = specDecParams.qCuSeqLens;

csrc/xqa/mha.h

Lines changed: 35 additions & 37 deletions

@@ -88,39 +88,39 @@ struct BeamSearchParams {
   // but we have to match trt-llm API.
 };

-void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
+void launchMHA(
+    cudaDeviceProp const& prop, uint32_t const nbKHeads,
 #if SLIDING_WINDOW
-               uint32_t slidingWinSize,
+    uint32_t slidingWinSize,
 #endif
-               float qScale, OutputHead* output,
+    float qScale, OutputHead* output,
 #if LOW_PREC_OUTPUT
-               float const* rcpOutScale,
+    float const* rcpOutScale,
 #endif
 #if USE_INPUT_KV
-               InputHead const* qkv,
+    InputHead const* qkv,
 #if ROPE_STYLE != 0
-               Vec<float, validElemsPerHead> const* ropeCosSin,
+    Vec<float, validElemsPerHead> const* ropeCosSin,
 #endif
 #else
-               InputHead const* q,
-#endif
-               float const* attentionSinks, // [headGrpSize]
-               GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
-               KVCachePageIndex const*
-                   kvCachePageList, // device pointer. shape:
-                                    // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
-               uint32_t maxSeqLen, uint32_t const* seqLen,
+    InputHead const* q,
+#endif
+    float const* attentionSinks, // [headGrpSize]
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+    KVCachePageIndex const*
+        kvCachePageList, // device pointer. shape:
+                         // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
+    uint32_t maxSeqLen, uint32_t const* seqLen,
 #if BEAM_WIDTH > 1
-               BeamSearchParams const& beamSearchParams,
+    BeamSearchParams const& beamSearchParams,
 #endif
-               uint32_t batchSize,
-               float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                                   // Used only for int8/fp8 KV cache.
+    uint32_t batchSize,
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
 #if SPEC_DEC
-               SpecDecParams const& specDecParams,
+    SpecDecParams const& specDecParams,
 #endif
-               uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
-               uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
+    uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
+    uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);

 void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
                          float qScale, OutputHead* output,

@@ -165,8 +165,7 @@ void launchHopperF8MHA(
     BeamSearchParams const& beamSearchParams,
 #endif
     uint32_t batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                        // Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif

@@ -188,18 +187,18 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
                                  uint64_t kv_stride_page, uint64_t kv_stride_token,
                                  uint64_t kv_stride_head, cudaStream_t stream);

-void launchMLA(cudaDeviceProp const& prop,
-               uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
-               float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM,
-               GMemCacheHead* vCacheVLLM,
-               KVCachePageIndex const*
-                   kvCachePageList, // device pointer. shape:
-                                    // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
-                                    // (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1)
-               uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
-               float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                                   // Used only for int8/fp8 KV cache.
-               uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);
+void launchMLA(
+    cudaDeviceProp const& prop,
+    uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
+    float qScale, OutputHead* output, InputHead const* q, GMemCacheHead* kCacheVLLM,
+    GMemCacheHead* vCacheVLLM,
+    KVCachePageIndex const*
+        kvCachePageList, // device pointer. shape:
+                         // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
+                         // (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1)
+    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
+    uint32_t* semaphores, void* scratch, bool enable_pdl, cudaStream_t stream);

 void launchMLAFlashInfer(
     uint32_t multiProcessorCount,

@@ -211,8 +210,7 @@ void launchMLAFlashInfer(
                          // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
                          // [batchSize][maxNbPagesPerSeq] (Layout 1)
     uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                        // Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
     uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
     uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
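The reworded comment describes a single scale shared by the K and V caches, used only when the cache is stored as int8/fp8; the parameter itself is a plain float. As a rough illustration of where such a scale (together with qScale) enters an attention computation, here is a small NumPy sketch; the names, shapes, and the qScale normalization convention are illustrative assumptions, not the kernels' actual implementation.

import numpy as np

def dequantize_kv(k_quant, v_quant, kv_cache_scale):
    # One scalar scale is shared by K and V, matching the "same scale for K and V cache" comment.
    return (k_quant.astype(np.float32) * kv_cache_scale,
            v_quant.astype(np.float32) * kv_cache_scale)

def ref_attention_with_scales(q, k_quant, v_quant, q_scale, kv_cache_scale):
    # Reference-style attention over a dequantized KV cache; only meant to show where the
    # scales enter, not to mirror the kernels' exact math.
    k, v = dequantize_kv(k_quant, v_quant, kv_cache_scale)
    logits = q_scale * (q @ k.T) / np.sqrt(q.shape[-1])  # qScale convention is illustrative
    logits -= logits.max(axis=-1, keepdims=True)         # numerically stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

# Example: one query head against a 16-token cache stored in a low-precision dtype.
q = np.random.default_rng(0).standard_normal((1, 64)).astype(np.float32)
k8 = np.random.default_rng(1).standard_normal((16, 64)).astype(np.float16)  # stand-in for fp8
v8 = np.random.default_rng(2).standard_normal((16, 64)).astype(np.float16)
out = ref_attention_with_scales(q, k8, v8, q_scale=1.0, kv_cache_scale=0.5)
assert out.shape == (1, 64)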

csrc/xqa/mha_sm90.cu

Lines changed: 2 additions & 4 deletions

@@ -626,8 +626,7 @@ __launch_bounds__(128 * 3)
     BeamSearchParams const beamSearchParams,
 #endif
     uint32_t const batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and
-                        // V cache. Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
     __grid_constant__ CUtensorMap const tensorMapVLLMK,
     __grid_constant__ CUtensorMap const tensorMapVLLMV,
 #if SPEC_DEC

@@ -2931,8 +2930,7 @@ void launchHopperF8MHA(
     BeamSearchParams const& beamSearchParams,
 #endif
     uint32_t batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                        // Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
 #if SPEC_DEC
     SpecDecParams const& specDecParams,
 #endif

csrc/xqa/mla_sm120.cu

Lines changed: 15 additions & 18 deletions

@@ -395,8 +395,7 @@ struct KernelArgs {
   OutputHead* __restrict__ const& output; // [totalNbIntputTokens][nbQHeads]
   KVCacheList<usePagedKVCache> const& cacheList;
   uint32_t const& batchSize;
-  float kvCacheScale; // Device memory scalar. Same scale for K and V
-                      // cache. Used only for int8/fp8 KV cache.
+  float kvCacheScale; // Same scale for K and V cache. Used only for int8/fp8 KV cache.
   Vec<CgaXBuffer, nbProducerCtasPerCga>* __restrict__ const&
       cgaXBuf; // [totalNbInputTokens][maxNbSubSeq]
   uint32_t* __restrict__ const& semaphores; // [totalNbInputTokens]

@@ -1553,8 +1552,7 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha
     float const qScale,
     OutputHead* __restrict__ const output, // [totalNbIntputTokens][nbQHeads]
     KVCacheList<usePagedKVCache> const cacheList, uint32_t const batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V
-                        // cache. Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
     Vec<CgaXBuffer,
         nbProducerCtasPerCga>* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq]
     uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens]

@@ -1648,18 +1646,18 @@ CUtensorMap makeTensorMapForQ(void const* addr, CUtensorMapDataType_enum dataTyp
 }
 #endif // IS_MLA

-void launchMLA(cudaDeviceProp const& prop,
-               uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
-               float qScale, OutputHead* output, InputHead const* q,
-               GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout
-               GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout
-               KVCachePageIndex const* kvCachePageList, // device pointer. shape:
-                                                        // [batchSize][maxNbPagesPerSeq] (Layout 1)
-               uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
-               float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                                   // Used only for int8/fp8 KV cache.
-               uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
-               uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
+void launchMLA(
+    cudaDeviceProp const& prop,
+    uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
+    float qScale, OutputHead* output, InputHead const* q,
+    GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout
+    GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout
+    KVCachePageIndex const* kvCachePageList, // device pointer. shape:
+                                             // [batchSize][maxNbPagesPerSeq] (Layout 1)
+    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
+    uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
+    uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
 #if IS_MLA
   static_assert(
       SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0,

@@ -1778,8 +1776,7 @@ void launchMLAFlashInfer(
     KVCachePageIndex const* kvCachePageList, // device pointer. shape:
                                              // [batchSize][maxNbPagesPerSeq] (Layout 1)
     uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
-    float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
-                        // Used only for int8/fp8 KV cache.
+    float kvCacheScale, // Same scale for K and V cache. Used only for int8/fp8 KV cache.
     uint32_t* semaphores, void* scratch, bool enable_pdl, uint64_t kv_stride_page,
     uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream) {
 #if IS_MLA

tests/attention/test_xqa.py

Lines changed: 12 additions & 3 deletions

@@ -29,7 +29,6 @@ def div_up(a, b):
 sm_count = props.multi_processor_count

 beam_width = 1
-q_scale = 1.0


 class CacheSeq:

@@ -181,6 +180,8 @@ def ref_attention(
 @pytest.mark.parametrize("valid_elems_per_head", [32, 128])
 @pytest.mark.parametrize("head_grp_size", [8, 16])
 @pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
+@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
+@pytest.mark.parametrize("q_scale", [1.0, 0.5])
 def test_xqa(
     batch_size,
     nb_k_heads,

@@ -194,7 +195,11 @@ def test_xqa(
     use_sliding_window,
     enable_pdl,
     kv_layout,
+    kv_scale,
+    q_scale,
 ):
+    if kv_scale != 1.0 and fp8_kv_cache is False:
+        pytest.skip("kv cache scale works only for fp8 kv cache")
     set_random_seed(42)

     nb_q_heads = nb_k_heads * head_grp_size

@@ -347,7 +352,7 @@ def cache_head_at(
     )
     seq_len_list.fill_(seq_len)

-    kv_cache_scale = 1.0
+    kv_cache_scale = kv_scale

     nb_seq = nb_k_heads * batch_size
     nb_semaphores = round_up(nb_seq, 2) + 2 + nb_seq + 2

@@ -443,6 +448,8 @@ def cache_head_at(
     get_compute_capability(torch.device(device="cuda"))[0] not in [12],
     reason="XQA mla is only supported on SM120 GPUs",
 )
+@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
+@pytest.mark.parametrize("q_scale", [1.0, 0.5])
 @pytest.mark.parametrize("enable_pdl", [True, False])
 @pytest.mark.parametrize("seq_len", [2, 15, 256, 514, 2048])
 @pytest.mark.parametrize("batch_size", [1, 2])

@@ -451,6 +458,8 @@ def test_xqa_mla(
     batch_size,
     seq_len,
     tokens_per_page,
+    kv_scale,
+    q_scale,
     enable_pdl,
 ):
     set_random_seed(42)

@@ -570,7 +579,7 @@ def cache_head_at(
     )
     seq_len_list.fill_(seq_len)

-    kv_cache_scale = 1.0
+    kv_cache_scale = kv_scale

     nb_seq = nb_k_heads * batch_size
     nb_semaphores = round_up(nb_seq, 2) + 2 + nb_seq + 2
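The test changes above follow the standard pytest parametrization pattern: kv_scale and q_scale become test parameters, unsupported combinations are skipped early, and kv_cache_scale is now taken from the kv_scale parameter instead of being hard-coded to 1.0. A stripped-down, self-contained sketch of the same pattern (the real test_xqa additionally builds the KV cache, launches the kernel, and checks against ref_attention, all omitted here):

import pytest

@pytest.mark.parametrize("fp8_kv_cache", [True, False])
@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
@pytest.mark.parametrize("q_scale", [1.0, 0.5])
def test_kv_scale_parametrization(fp8_kv_cache, kv_scale, q_scale):
    # Same guard as in test_xqa: a non-unit KV-cache scale is only meaningful for fp8 KV cache.
    if kv_scale != 1.0 and fp8_kv_cache is False:
        pytest.skip("kv cache scale works only for fp8 kv cache")

    # In the real test this value feeds the kernel launch; here we only check the bookkeeping
    # so the snippet stays runnable on its own.
    kv_cache_scale = kv_scale
    assert kv_cache_scale in (1.0, 0.5)
    assert q_scale in (1.0, 0.5)

Running this with pytest executes eight parameter combinations and skips the two where kv_scale is 0.5 but the KV cache is not fp8.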
