 
 import torch
 
-from .xqa import xqa
+from .xqa import xqa, xqa_mla
 from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
 from .jit import (
     gen_batch_decode_mla_module,
@@ -2437,11 +2437,9 @@ def xqa_batch_decode_with_kv_cache(
     kv_scale_value = bmm2_scale
     q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
 
-    query_new = query.unsqueeze(1).contiguous()
-    seq_lens_new = seq_lens.unsqueeze(1).contiguous()
-    sinks_new = (
-        sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
-    )
+    query_new = query.unsqueeze(1)
+    seq_lens_new = seq_lens.unsqueeze(1)
+    sinks_new = sinks.reshape(num_kv_heads, -1) if sinks is not None else None
 
     # Ensure 4D output for xqa
     if out is None:
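The scale handling in this hunk recovers the separate q/kv scales that the xqa kernel expects from the fused (bmm1_scale, bmm2_scale) pair. A minimal standalone sketch of that arithmetic, assuming the fusion convention bmm1_scale = q_scale * kv_scale / sqrt(head_dim) and bmm2_scale = kv_scale; the scale values below are placeholders, not values from this change:

import math

head_dim = 64
q_scale, kv_scale = 1.25, 0.5  # illustrative per-tensor scales
bmm1_scale = q_scale * kv_scale / math.sqrt(head_dim)
bmm2_scale = kv_scale

# Mirrors the wrapper code above: the kv scale passes through, and the q scale is
# recovered by dividing out the kv scale and multiplying sqrt(head_dim) back in.
kv_scale_value = bmm2_scale
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
assert math.isclose(q_scale_value, q_scale)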
@@ -2530,6 +2528,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     bmm2_scale_tensor: Optional[torch.Tensor] = None,
     sinks: Optional[List[torch.Tensor]] = None,
     enable_pdl: bool = None,
+    backend: str = "auto",
 ) -> torch.Tensor:
     """
     Parameters:
@@ -2548,6 +2547,173 @@ def trtllm_batch_decode_with_kv_cache_mla(
         bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
         bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
         sinks: additional value per head in the denominator of the softmax.
+        backend: str = "auto"
+            The implementation backend; can be ``auto``, ``xqa``, or ``trtllm-gen``. Defaults to ``auto``.
+            When set to ``auto``, the backend is chosen based on the device architecture and kernel availability.
+            For sm_100 and sm_103 (Blackwell architecture), ``auto`` will choose the ``trtllm-gen`` backend.
+            For sm_120 (Blackwell architecture), ``auto`` will choose the ``xqa`` backend.
+
+    Note:
+        In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
+        bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
+        bmm2_scale = v_scale * o_scale
+        or,
+        bmm1_scale_log2_tensor = [q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) * M_LOG2E]
+        bmm2_scale_tensor = [v_scale * o_scale]
+
+        The two scale factors should be static constants for CUDA graph capture.
+        Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
+
+        For static constant scale factors, the scale factors should be provided as float.
+        - (bmm1_scale, bmm2_scale)
+        For on-device fused scale tensors, which can change dynamically, the scale factors should be provided as torch.Tensor.
+        - (bmm1_scale_log2_tensor, bmm2_scale_tensor)
+        - Currently, only the fp8 tensor core operation supports this mode.
+        When both are provided, the dynamic scale factor tensors will be used.
+    """
+    if backend == "auto":
+        backend = (
+            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
+        )
+    if backend == "xqa":
+        if (
+            get_compute_capability(query.device)[0] != 12
+            or query.dtype != torch.float8_e4m3fn
+            or kv_cache.dtype != torch.float8_e4m3fn
+        ):
+            raise ValueError(
+                f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
+            )
+        if sinks is not None:
+            raise ValueError("XQA MLA does not support sinks")
+        if query.size(1) != 1:
+            raise ValueError(
+                f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
+            )
+        return xqa_batch_decode_with_kv_cache_mla(
+            query,
+            kv_cache,
+            workspace_buffer,
+            qk_nope_head_dim,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            block_tables,
+            seq_lens,
+            max_seq_len,
+            out,
+            bmm1_scale,
+            bmm2_scale,
+            sinks=sinks,
+            enable_pdl=enable_pdl,
+        )
+    elif backend == "trtllm-gen":
+        enable_pdl = (
+            device_support_pdl(query.device) if enable_pdl is None else enable_pdl
+        )
+        run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
+        sm_count = get_device_sm_count(query.device)
+
+        block_size = kv_cache.size(-2)
+        if (
+            block_size != 32 and block_size != 64
+        ):  # todo(Yingyi): add support for more block sizes?
+            raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")
+
+        _check_trtllm_gen_mla_shape(
+            query,
+            kv_cache,
+            qk_nope_head_dim,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            block_tables,
+            block_size,
+        )
+
+        if out is None:
+            out_shape = query.shape[:-1] + (kv_lora_rank,)
+            out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
+        else:
+            batch_size, _, num_q_heads, _ = query.shape
+            check_shape_dtype_device(
+                out,
+                [batch_size, num_q_heads, kv_lora_rank],
+                torch.bfloat16,
+                query.device,
+                "out",
+            )
+
+        if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
+            # dynamic scale factors
+            if (
+                query.dtype != torch.float8_e4m3fn
+                or kv_cache.dtype != torch.float8_e4m3fn
+            ):
+                raise ValueError(
+                    "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
+                )
+
+        run_func(
+            out,
+            None,  # fp4 output not supported in wrapper api yet.
+            query,
+            kv_cache,
+            kv_cache,
+            workspace_buffer,
+            block_tables,
+            seq_lens,
+            max_seq_len,
+            bmm1_scale,
+            bmm2_scale,
+            -1,  # o_sf_scale
+            -1,  # o_sf_vec_size
+            0,  # o_sf_start_index
+            -1,  # window_left
+            sm_count,
+            enable_pdl,
+            workspace_buffer.numel() * workspace_buffer.element_size(),
+            sinks,
+        )
+
+        return out
+    else:
+        raise ValueError(f"Backend {backend} not supported")
+
+
+def xqa_batch_decode_with_kv_cache_mla(
+    query: torch.Tensor,
+    kv_cache: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    qk_nope_head_dim: int,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_seq_len: int,
+    out: Optional[torch.Tensor] = None,
+    bmm1_scale: Optional[float] = 1.0,
+    bmm2_scale: Optional[float] = 1.0,
+    bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
+    bmm2_scale_tensor: Optional[torch.Tensor] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
+    enable_pdl: bool = None,
+) -> torch.Tensor:
+    """
+    Parameters:
+        query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concatenated q_nope + q_rope; q_len_per_request is the MTP query length.
+        kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concatenated ckv_cache + kpe_cache
+        workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
+        qk_nope_head_dim: qk_nope_head_dim, must be 128
+        kv_lora_rank: kv_lora_rank, must be 512
+        qk_rope_head_dim: qk_rope_head_dim, must be 64
+        block_tables: page_table of kv cache, [batch_size, num_pages]
+        seq_lens: query_len
+        max_seq_len: max sequence length for kv_cache
+        out: output tensor; if not provided, will be allocated internally
+        bmm1_scale: fused scale for mla bmm1 input.
+        bmm2_scale: fused scale for mla bmm2 input.
+        bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
+        bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
+        sinks: additional value per head in the denominator of the softmax.
 
     Note:
         In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
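For reference, the fused scales described in these docstrings can be assembled on the host before the call. A small sketch following the documented formulas; the per-tensor scale values and head_dim_qk below are illustrative assumptions, not values from this change:

import math
import torch

# Placeholder per-tensor scales; in practice these come from the model's
# fp8 quantization metadata.
q_scale, k_scale, v_scale, o_scale, sm_scale = 1.0, 0.5, 0.5, 1.0, 1.0
head_dim_qk = 576  # assumed kv_lora_rank (512) + qk_rope_head_dim (64)

# Static scalar variant (kept constant across CUDA graph capture):
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale

# On-device tensor variant; bmm1 must be pre-multiplied by M_LOG2E per the docstring:
M_LOG2E = math.log2(math.e)
bmm1_scale_log2_tensor = torch.tensor([bmm1_scale * M_LOG2E], device="cuda")
bmm2_scale_tensor = torch.tensor([bmm2_scale], device="cuda")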
@@ -2568,14 +2734,20 @@ def trtllm_batch_decode_with_kv_cache_mla(
         When both are provided, the dynamic scale factor tensors will be used.
     """
     enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
-    run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
     sm_count = get_device_sm_count(query.device)
 
     block_size = kv_cache.size(-2)
-    if (
-        block_size != 32 and block_size != 64
-    ):  # todo(Yingyi): add support for more block sizes?
-        raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")
+    q_len_per_request = query.size(1)
+    if q_len_per_request != 1:
+        raise ValueError(
+            f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
+        )
+    if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
+        raise ValueError(
+            f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
+        )
+    if sinks is not None:
+        raise ValueError("XQA MLA does not support sinks")
 
     _check_trtllm_gen_mla_shape(
         query,
@@ -2600,33 +2772,27 @@ def trtllm_batch_decode_with_kv_cache_mla(
             "out",
         )
 
-    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
-        # dynamic scale factors
-        if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
-            raise ValueError(
-                "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
-            )
+    workspace_u8 = workspace_buffer.view(torch.uint8)
+    semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for the semaphore
+    scratch = workspace_u8[8 * 1024 * 1024 :]
+    # This cannot be replaced by kv_cache.transpose(1, 2) because the stride is not the same
+    kv_cache_new = kv_cache.squeeze(1).unsqueeze(2)
+    seq_lens_new = seq_lens.unsqueeze(1)
 
-    run_func(
-        out,
-        None,  # fp4 output not supported in wrapper api yet.
+    xqa_mla(
         query,
-        kv_cache,
-        kv_cache,
-        workspace_buffer,
+        kv_cache_new,
+        kv_cache_new,
         block_tables,
-        seq_lens,
-        max_seq_len,
-        bmm1_scale,
-        bmm2_scale,
-        -1,  # o_sf_scale
-        -1,  # o_sf_vec_size
-        0,  # o_sf_start_index
-        -1,  # window_left
-        sm_count,
-        enable_pdl,
-        workspace_buffer.numel() * workspace_buffer.element_size(),
-        sinks,
+        seq_lens_new,
+        out,
+        scratch,
+        semaphore,
+        block_size,
+        q_scale=bmm1_scale,
+        kv_scale=bmm2_scale,
+        sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )
 
     return out
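Finally, a minimal usage sketch of the updated entry point with the new backend argument. Shapes and dtypes follow the parameter notes above; the concrete sizes, the workspace size, the import path, and the singleton kv-head dimension in the cache layout are assumptions for illustration:

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla  # import path assumed

batch_size, num_heads = 4, 128
kv_lora_rank, qk_rope_head_dim = 512, 64      # required values per the docstring
head_dim = kv_lora_rank + qk_rope_head_dim    # 576
page_size, max_seq_len = 32, 1024
pages_per_seq = max_seq_len // page_size
num_pages = batch_size * pages_per_seq

# fp8 query/kv cache as required by the xqa backend (q_len_per_request == 1).
query = torch.randn(batch_size, 1, num_heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
kv_cache = torch.randn(num_pages, 1, page_size, head_dim, device="cuda").to(torch.float8_e4m3fn)
block_tables = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(batch_size, pages_per_seq)
seq_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device="cuda")
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # zero-initialized for first use

out = trtllm_batch_decode_with_kv_cache_mla(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace_buffer,
    qk_nope_head_dim=128,
    kv_lora_rank=kv_lora_rank,
    qk_rope_head_dim=qk_rope_head_dim,
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_seq_len=max_seq_len,
    bmm1_scale=1.0 / head_dim**0.5,  # placeholder fusion with unit per-tensor scales
    bmm2_scale=1.0,
    backend="auto",  # trtllm-gen on sm_100/sm_103, xqa on sm_120
)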