
Commit d42fb90

feat: add xqa mla backend (#2053)
1 parent f5a06a4 commit d42fb90

4 files changed, +454 -69 lines changed


csrc/xqa/mla_sm120.cu

Lines changed: 3 additions & 3 deletions
@@ -1790,17 +1790,17 @@ void launchMLAFlashInfer(
     uint32_t const nbVHeads = nbKHeads;
     uint32_t const nbQHeads = nbKHeads * headGrpSize;
     uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
-    uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
+    /*uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
         float const factor = 4.f;
         return mha::min<uint32_t>(
             mha::max<uint32_t>(
                 1U, (uint32_t)round(multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)),
             divUp(maxSeqLen, tokensPerTile * 2));
-    }();
+    }();*/ // MLA disables multi-block mode for now
     // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq);
     // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x ==
     // nbInputSeqSplit
-    dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
+    dim3 const dimGrid{4 * inputSeqLen, 1, nbKHeads * batchSize};
     dim3 const dimCta{warp_size * 4 * 3, 1, 1};
     auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, enable_pdl);
     uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
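
For reference, the multi-block (split-KV) heuristic that this change disables can be transcribed into Python as below. This is only an illustrative restatement of the commented-out lambda above, not code from the commit; the function and argument names are made up for the sketch.

import math

def nb_sub_seq_per_seq(multi_processor_count: int, batch_size: int, nb_k_heads: int,
                       max_seq_len: int, tokens_per_tile: int) -> int:
    # Mirrors the disabled C++ lambda: integer division up to the final scaling by `factor`.
    factor = 4.0
    by_occupancy = max(1, round(multi_processor_count // 4 // (batch_size * nb_k_heads) * factor))
    by_seq_len = math.ceil(max_seq_len / (tokens_per_tile * 2))  # divUp(maxSeqLen, tokensPerTile * 2)
    return min(by_occupancy, by_seq_len)

# After this commit the MLA launch behaves as if this always returned 1:
# gridDim.y is hard-coded to 1, i.e. multi-block mode is disabled for now.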

flashinfer/decode.py

Lines changed: 201 additions & 35 deletions
@@ -21,7 +21,7 @@
 
 import torch
 
-from .xqa import xqa
+from .xqa import xqa, xqa_mla
 from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
 from .jit import (
     gen_batch_decode_mla_module,
@@ -2437,11 +2437,9 @@ def xqa_batch_decode_with_kv_cache(
     kv_scale_value = bmm2_scale
     q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
 
-    query_new = query.unsqueeze(1).contiguous()
-    seq_lens_new = seq_lens.unsqueeze(1).contiguous()
-    sinks_new = (
-        sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
-    )
+    query_new = query.unsqueeze(1)
+    seq_lens_new = seq_lens.unsqueeze(1)
+    sinks_new = sinks.reshape(num_kv_heads, -1) if sinks is not None else None
 
     # Ensure 4D output for xqa
     if out is None:
@@ -2530,6 +2528,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     bmm2_scale_tensor: Optional[torch.Tensor] = None,
     sinks: Optional[List[torch.Tensor]] = None,
     enable_pdl: bool = None,
+    backend: str = "auto",
 ) -> torch.Tensor:
     """
     Parameters:
@@ -2548,6 +2547,173 @@ def trtllm_batch_decode_with_kv_cache_mla(
     bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
     bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
     sinks: additional value per head in the denominator of the softmax.
+    backend : str = "auto"
+        The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
+        When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
+        For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
+        For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
+
+    Note:
+        In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
+        bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
+        bmm2_scale = v_scale * o_scale
+        or,
+        bmm1_scale_log2_tensor = [q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) * M_LOG2E]
+        bmm2_scale_tensor = [v_scale * o_scale]
+
+        The two scale factors should be static constant for cuda graph capture.
+        Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
+
+        For static constant scale factors, the scale factors should be provided as float.
+        - (bmm1_scale, bmm2_scale)
+        For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
+        - (bmm1_scale_log2_tensor, bmm2_scale_tensor)
+        - Currently, only fp8 tensor core operation supports this mode.
+        When both are provided, the dynamic scale factor tensors will be used.
+    """
+    if backend == "auto":
+        backend = (
+            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
+        )
+    if backend == "xqa":
+        if (
+            get_compute_capability(query.device)[0] != 12
+            or query.dtype != torch.float8_e4m3fn
+            or kv_cache.dtype != torch.float8_e4m3fn
+        ):
+            raise ValueError(
+                f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
+            )
+        if sinks is not None:
+            raise ValueError("XQA MLA does not support sinks")
+        if query.size(1) != 1:
+            raise ValueError(
+                f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
+            )
+        return xqa_batch_decode_with_kv_cache_mla(
+            query,
+            kv_cache,
+            workspace_buffer,
+            qk_nope_head_dim,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            block_tables,
+            seq_lens,
+            max_seq_len,
+            out,
+            bmm1_scale,
+            bmm2_scale,
+            sinks,
+            enable_pdl,
+        )
+    elif backend == "trtllm-gen":
+        enable_pdl = (
+            device_support_pdl(query.device) if enable_pdl is None else enable_pdl
+        )
+        run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
+        sm_count = get_device_sm_count(query.device)
+
+        block_size = kv_cache.size(-2)
+        if (
+            block_size != 32 and block_size != 64
+        ):  # todo(Yingyi): add support for more block sizes?
+            raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")
+
+        _check_trtllm_gen_mla_shape(
+            query,
+            kv_cache,
+            qk_nope_head_dim,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            block_tables,
+            block_size,
+        )
+
+        if out is None:
+            out_shape = query.shape[:-1] + (kv_lora_rank,)
+            out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
+        else:
+            batch_size, _, num_q_heads, _ = query.shape
+            check_shape_dtype_device(
+                out,
+                [batch_size, num_q_heads, kv_lora_rank],
+                torch.bfloat16,
+                query.device,
+                "out",
+            )
+
+        if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
+            # dynamic scale factors
+            if (
+                query.dtype != torch.float8_e4m3fn
+                or kv_cache.dtype != torch.float8_e4m3fn
+            ):
+                raise ValueError(
+                    "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
+                )
+
+        run_func(
+            out,
+            None,  # fp4 output not supported in wrapper api yet.
+            query,
+            kv_cache,
+            kv_cache,
+            workspace_buffer,
+            block_tables,
+            seq_lens,
+            max_seq_len,
+            bmm1_scale,
+            bmm2_scale,
+            -1,  # o_sf_scale
+            -1,  # o_sf_vec_size
+            0,  # o_sf_start_index
+            -1,  # window_left
+            sm_count,
+            enable_pdl,
+            workspace_buffer.numel() * workspace_buffer.element_size(),
+            sinks,
+        )
+
+        return out
+    else:
+        raise ValueError(f"Backend {backend} not supported")
+
+
+def xqa_batch_decode_with_kv_cache_mla(
+    query: torch.Tensor,
+    kv_cache: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    qk_nope_head_dim: int,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_seq_len: int,
+    out: Optional[torch.Tensor] = None,
+    bmm1_scale: Optional[float] = 1.0,
+    bmm2_scale: Optional[float] = 1.0,
+    bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
+    bmm2_scale_tensor: Optional[torch.Tensor] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
+    enable_pdl: bool = None,
+) -> torch.Tensor:
+    """
+    Parameters:
+    query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
+    kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
+    workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
+    qk_nope_head_dim: qk_nope_head_dim, must be 128
+    kv_lora_rank: kv_lora_rank, must be 512
+    qk_rope_head_dim: qk_rope_head_dim, must be 64
+    block_tables: page_table of kv cache, [batch_size, num_pages]
+    seq_lens: query_len
+    max_seq_len: max sequence length for kv_cache
+    out: output tensor, if not provided, will be allocated internally
+    bmm1_scale: fused scale for mla bmm1 input.
+    bmm2_scale: fused scale for mla bmm2 input.
+    bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
+    bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
+    sinks: additional value per head in the denominator of the softmax.
 
     Note:
         In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
@@ -2568,14 +2734,20 @@ def trtllm_batch_decode_with_kv_cache_mla(
         When both are provided, the dynamic scale factor tensors will be used.
     """
     enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
-    run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
     sm_count = get_device_sm_count(query.device)
 
     block_size = kv_cache.size(-2)
-    if (
-        block_size != 32 and block_size != 64
-    ):  # todo(Yingyi): add support for more block sizes?
-        raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")
+    q_len_per_request = query.size(1)
+    if q_len_per_request != 1:
+        raise ValueError(
+            f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
+        )
+    if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
+        raise ValueError(
+            f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
+        )
+    if sinks is not None:
+        raise ValueError("XQA MLA does not support sinks")
 
     _check_trtllm_gen_mla_shape(
         query,
@@ -2600,33 +2772,27 @@ def trtllm_batch_decode_with_kv_cache_mla(
         "out",
     )
 
-    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
-        # dynamic scale factors
-        if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
-            raise ValueError(
-                "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
-            )
+    workspace_u8 = workspace_buffer.view(torch.uint8)
+    semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
+    scratch = workspace_u8[8 * 1024 * 1024 :]
+    # This can not be replaced by kv_cache.transpose(1, 2) because the stride is not the same
+    kv_cache_new = kv_cache.squeeze(1).unsqueeze(2)
+    seq_lens_new = seq_lens.unsqueeze(1)
 
-    run_func(
-        out,
-        None,  # fp4 output not supported in wrapper api yet.
+    xqa_mla(
         query,
-        kv_cache,
-        kv_cache,
-        workspace_buffer,
+        kv_cache_new,
+        kv_cache_new,
         block_tables,
-        seq_lens,
-        max_seq_len,
-        bmm1_scale,
-        bmm2_scale,
-        -1,  # o_sf_scale
-        -1,  # o_sf_vec_size
-        0,  # o_sf_start_index
-        -1,  # window_left
-        sm_count,
-        enable_pdl,
-        workspace_buffer.numel() * workspace_buffer.element_size(),
-        sinks,
+        seq_lens_new,
+        out,
+        scratch,
+        semaphore,
+        block_size,
+        q_scale=bmm1_scale,
+        kv_scale=bmm2_scale,
+        sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )
 
     return out
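
As a usage sketch (not part of the commit): the call below exercises the new backend="xqa" path through trtllm_batch_decode_with_kv_cache_mla, following the fused-scale convention from the docstring. The DeepSeek-style head count, page size, scale values, workspace size, and the paged-cache layout [num_pages, 1, page_size, 576] (implied by the squeeze(1)/unsqueeze(2) reshape in the xqa path) are illustrative assumptions; this path requires an SM120 GPU and fp8 inputs.

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

device = "cuda"
batch_size, num_q_heads = 4, 128            # illustrative decode batch
kv_lora_rank, qk_rope_head_dim = 512, 64    # head_dim_qk = 512 + 64 = 576
page_size, max_seq_len = 64, 1024
pages_per_seq = max_seq_len // page_size

# fp8 query and paged KV cache, as required by the xqa backend.
query = torch.randn(
    batch_size, 1, num_q_heads, kv_lora_rank + qk_rope_head_dim, device=device
).to(torch.float8_e4m3fn)
kv_cache = torch.randn(
    batch_size * pages_per_seq, 1, page_size, kv_lora_rank + qk_rope_head_dim, device=device
).to(torch.float8_e4m3fn)
block_tables = torch.arange(
    batch_size * pages_per_seq, dtype=torch.int32, device=device
).reshape(batch_size, pages_per_seq)
seq_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device)

# Workspace must be zero-initialized before its first use; the xqa path carves
# an 8MB semaphore region out of it and uses the remainder as scratch.
workspace = torch.zeros(144 * 1024 * 1024, dtype=torch.uint8, device=device)

# Fused scales, per the docstring:
#   bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
#   bmm2_scale = v_scale * o_scale
q_scale = k_scale = v_scale = o_scale = sm_scale = 1.0  # illustrative values
head_dim_qk = kv_lora_rank + qk_rope_head_dim

out = trtllm_batch_decode_with_kv_cache_mla(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace,
    qk_nope_head_dim=128,
    kv_lora_rank=kv_lora_rank,
    qk_rope_head_dim=qk_rope_head_dim,
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_seq_len=max_seq_len,
    bmm1_scale=q_scale * k_scale * sm_scale / (head_dim_qk**0.5),
    bmm2_scale=v_scale * o_scale,
    backend="xqa",  # "auto" picks trtllm-gen when the CC major is 10, xqa otherwise
)
print(out.shape, out.dtype)  # bfloat16 output over kv_lora_rank channels

The comment about not replacing kv_cache.squeeze(1).unsqueeze(2) with kv_cache.transpose(1, 2) comes down to strides: both produce the same [num_pages, page_size, 1, head_dim] shape, but the stride recorded for the size-1 dimension differs, which matters if the kernel consumes the strides directly (an assumption here; the commit only states the strides differ). A minimal check under the layout assumed above:

import torch

kv = torch.zeros(2, 1, 64, 576)      # [num_pages, 1, page_size, head_dim]
a = kv.squeeze(1).unsqueeze(2)       # what the wrapper does
b = kv.transpose(1, 2)               # same shape, different strides
print(a.shape, a.stride())           # torch.Size([2, 64, 1, 576]) (36864, 576, 576, 1)
print(b.shape, b.stride())           # torch.Size([2, 64, 1, 576]) (36864, 576, 36864, 1)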
