From bb955d2a5b8cfe064a252e56d43daa1302eb5a2c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 15 Oct 2025 06:17:45 +0000 Subject: [PATCH 01/21] [FP8-KV-CACHE] Init vers. --- .../srt/layers/attention/aiter_backend.py | 91 ++++++++++++++++++- .../layers/quantization/quark/quark_moe.py | 15 +-- .../sglang/srt/model_executor/model_runner.py | 10 +- 3 files changed, 105 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index dafe5ee19c4a..f66fb15bdb6c 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -38,6 +38,15 @@ from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + scaled_fp8_quant, +) + +import math + +from aiter import get_mla_metadata_v1 + class WrapperDispatch(Enum): SLIDING_WINDOW = auto() @@ -605,6 +614,8 @@ def forward_extend( kv_indptr = self.forward_metadata.kv_indptr kv_indices = self.forward_metadata.kv_indices qo_indptr = self.forward_metadata.qo_indptr + #K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(torch.bfloat16) + #V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).to(torch.bfloat16) K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) kv_lora_rank = V_Buffer.shape[-1] @@ -619,7 +630,8 @@ def forward_extend( and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() ): - if kv_indices.shape[0] == 0: + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + if kv_indices.shape[0] == 0 or extend_no_prefix: o = flash_attn_varlen_func( q, k, @@ -645,6 +657,10 @@ def forward_extend( k_prefix, v_prefix = torch.split( kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 ) + + #if self.kv_cache_dtype == fp8_dtype: + # k_pe = k_pe.to(torch.bfloat16) + k_prefix = torch.cat( [ k_prefix, @@ -805,9 +821,74 @@ def forward_decode( ) if self.use_mla: + #k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(torch.bfloat16) k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = layer.tp_q_head_num + max_qo_tiles_per_batch = int(math.ceil(self.forward_metadata.max_q_len * nhead / 128)) + + + batch_size = forward_batch.batch_size + + work_meta_data = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + page_size = 1 + nhead_kv = 1 + mtp = 1 + + split_params = { + "kv_granularity": max(page_size, 16), + "max_seqlen_qo": self.forward_metadata.max_q_len, + "uni_seqlen_qo": mtp, + "fast_mode": 1, + } + + meta = get_mla_metadata_v1( + self.forward_metadata.qo_indptr, + 
self.forward_metadata.kv_indptr,
+                nhead // nhead_kv,
+                nhead_kv,
+                True,
+                work_meta_data,
+                work_info_set,
+                work_indptr,
+                reduce_indptr,
+                reduce_final_map,
+                reduce_partial_map,
+                split_params=split_params
+            )
+
+
+            if self.kv_cache_dtype == fp8_dtype:
+                q_input, q_scale = scaled_fp8_quant(
+                    q,
+                )
+            else:
+                q_input = q
+
             mla_decode_fwd(
-                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                 k_buffer.view(-1, 1, 1, layer.qk_head_dim),
                 o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                 self.forward_metadata.qo_indptr,
@@ -817,6 +898,12 @@ def forward_decode(
                 self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
+                work_meta_data=work_meta_data,
+                work_indptr=work_indptr,
+                work_info_set=work_info_set,
+                reduce_indptr=reduce_indptr,
+                reduce_final_map=reduce_final_map,
+                reduce_partial_map=reduce_partial_map,
             )
             k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
         else:
diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py
index d1ad13f4810e..1d52fd46d945 100644
--- a/python/sglang/srt/layers/quantization/quark/quark_moe.py
+++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py
@@ -189,12 +189,15 @@ def apply(
             torch.float32
         )  # aiter's moe_sorting requires topk_weights to be FP32
 
-        if hasattr(torch, "float4_e2m1fn_x2"):
-            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
-            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
-        else:
-            w13_weight = layer.w13_weight
-            w2_weight = layer.w2_weight
+        #if hasattr(torch, "float4_e2m1fn_x2"):
+        #    w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+        #    w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        #else:
+        #    w13_weight = layer.w13_weight
+        #    w2_weight = layer.w2_weight
+
+        w13_weight = layer.w13_weight
+        w2_weight = layer.w2_weight
 
         output = fused_moe(
             x,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 39ee02aaf673..f2a06ad394be 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -161,6 +161,10 @@
     FlattenedTensorMetadata,
 )
 
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+)
+
 MLA_ATTENTION_BACKENDS = [
     "aiter",
     "flashinfer",
@@ -1558,19 +1562,19 @@ def init_memory_pool(
                 and kv_cache_quant_algo.upper() == "FP8"
             ):
                 if _is_hip:
-                    self.kv_cache_dtype = torch.float8_e4m3fnuz
+                    self.kv_cache_dtype = fp8_dtype
                 else:
                     self.kv_cache_dtype = torch.float8_e4m3fn
             else:
                 self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e5m2fnuz
+                self.kv_cache_dtype = fp8_dtype
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
             if _is_hip:  # Using natively supported format
-                self.kv_cache_dtype = torch.float8_e4m3fnuz
+                self.kv_cache_dtype = fp8_dtype
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:

From f558e8b0c1e0d74fe1fb08687e597b72a3bdea1a Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 15 Oct 2025 13:38:27 +0000
Subject: [PATCH 02/21] [FP8 KV-CACHE] Force dtype conversion from fp8 to
 bfloat16 to enable fp8 kv_cache

---
 python/sglang/srt/layers/attention/aiter_backend.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)
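Note on this patch (editorial): it takes the simplest dequantize-on-read route.
When the cache holds fp8, the key buffer is upcast to bf16 before the
unquantized MLA prefill math runs. A minimal sketch of the idea, not part of
the patch (the helper name is invented for illustration; fp8_dtype is the
import patch 01 adds):

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype

    def read_key_buffer_as_bf16(pool, layer_id, kv_cache_dtype):
        k = pool.get_key_buffer(layer_id)  # stored as fp8 when the cache is quantized
        if kv_cache_dtype == fp8_dtype:
            # this series stores fp8 values with an implicit scale of 1.0,
            # so a plain cast recovers bf16 values for the attention math
            k = k.to(torch.bfloat16)
        return k

This trades extra HBM traffic and a transient bf16 copy for correctness; later
patches in the series push the conversion down into the decode kernel via
q/kv scales instead.

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index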
f66fb15bdb6c..bef2982edc7e 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -614,9 +614,11 @@ def forward_extend( kv_indptr = self.forward_metadata.kv_indptr kv_indices = self.forward_metadata.kv_indices qo_indptr = self.forward_metadata.qo_indptr - #K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(torch.bfloat16) - #V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).to(torch.bfloat16) - K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + if self.kv_cache_dtype == fp8_dtype: + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(torch.bfloat16) + else: + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) kv_lora_rank = V_Buffer.shape[-1] qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank @@ -658,9 +660,6 @@ def forward_extend( kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 ) - #if self.kv_cache_dtype == fp8_dtype: - # k_pe = k_pe.to(torch.bfloat16) - k_prefix = torch.cat( [ k_prefix, From 7fc664bb77dade65e2dde4517b711c21e48f548a Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Oct 2025 09:04:10 +0000 Subject: [PATCH 03/21] Pass mla_decode_fwd accuracy test --- .../srt/layers/attention/aiter_backend.py | 28 ++++++++++++------- .../layers/quantization/quark/quark_moe.py | 15 ++++------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index bef2982edc7e..bee8ef197917 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -614,11 +614,8 @@ def forward_extend( kv_indptr = self.forward_metadata.kv_indptr kv_indices = self.forward_metadata.kv_indices qo_indptr = self.forward_metadata.qo_indptr - if self.kv_cache_dtype == fp8_dtype: - K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(torch.bfloat16) - else: - K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) kv_lora_rank = V_Buffer.shape[-1] qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank @@ -651,6 +648,11 @@ def forward_extend( kvc, k_pe = torch.split( K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1 ) + + if self.kv_cache_dtype == fp8_dtype: + kvc = kvc.to(torch.bfloat16) + k_pe = k_pe.to(torch.bfloat16) + kvprefix = layer.kv_b_proj(kvc.contiguous())[0] kvprefix = kvprefix.view( @@ -659,7 +661,7 @@ def forward_extend( k_prefix, v_prefix = torch.split( kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 ) - + k_prefix = torch.cat( [ k_prefix, @@ -717,7 +719,7 @@ def forward_extend( o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) mla_decode_fwd( q, - K_Buffer.view(-1, 1, 1, layer.qk_head_dim), + K_Buffer.view(-1, 1, 1, layer.qk_head_dim).to(torch.bfloat16), o, self.forward_metadata.qo_indptr, self.forward_metadata.kv_indptr, @@ -737,7 +739,7 @@ def forward_extend( kv_indices = self.forward_metadata.kv_indices mla_prefill_fwd( q, - K_Buffer.view(-1, 1, 1, layer.qk_head_dim), + K_Buffer.view(-1, 1, 1, layer.qk_head_dim).to(torch.bfloat16), o, self.forward_metadata.qo_indptr, self.forward_metadata.kv_indptr, @@ -875,14 +877,18 @@ def forward_decode( 
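# (Editorial note, kept as comments so the hunk below still reads as code.)
# This hunk replaces the split_params dict with explicit keyword arguments to
# get_mla_metadata_v1, and the fp8 path now hands per-tensor scales to
# mla_decode_fwd instead of upcasting the cache. A sketch of that step,
# assuming scaled_fp8_quant returns (quantized tensor, per-tensor scale):
#
#     q_input, q_scale = scaled_fp8_quant(q)   # q_input: fp8, q_scale: fp32 scalar
#     kv_scale = torch.ones([1], dtype=torch.float, device="cuda")  # cache is unscaled
#     # mla_decode_fwd(..., q_scale=q_scale, kv_scale=kv_scale) is then expected
#     # to fold both scales back in during the attention computation.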
reduce_indptr, reduce_final_map, reduce_partial_map, - split_params=split_params + kv_granularity=max(page_size, 16), + max_seqlen_qo=self.forward_metadata.max_q_len, + uni_seqlen_qo=self.forward_metadata.max_q_len, + fast_mode=True, ) - if self.kv_cache_dtype == fp8_dtype: q_input, q_scale = scaled_fp8_quant( q, ) + q_scale = q_scale.to(torch.float) + kv_scale = torch.ones([1], dtype=torch.float, device="cuda") else: q_input = q @@ -903,8 +909,10 @@ def forward_decode( reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + q_scale=q_scale, + kv_scale=kv_scale, ) - k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim) + #k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim) else: self.logits_soft_cap = layer.logit_cap paged_attention_ragged( diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py index 1d52fd46d945..d1ad13f4810e 100644 --- a/python/sglang/srt/layers/quantization/quark/quark_moe.py +++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py @@ -189,15 +189,12 @@ def apply( torch.float32 ) # aiter's moe_sorting requires topk_weights to be FP32 - #if hasattr(torch, "float4_e2m1fn_x2"): - # w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) - # w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) - #else: - # w13_weight = layer.w13_weight - # w2_weight = layer.w2_weight - - w13_weight = layer.w13_weight - w2_weight = layer.w2_weight + if hasattr(torch, "float4_e2m1fn_x2"): + w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2) + w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2) + else: + w13_weight = layer.w13_weight + w2_weight = layer.w2_weight output = fused_moe( x, From 8245d3a0f8afebe0a9373d6e32316469a2901392 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 30 Oct 2025 07:02:54 +0000 Subject: [PATCH 04/21] Fix no scale issue for non-fp8 kv and default to use non-persistent mla_decode_forward kernel --- .../srt/layers/attention/aiter_backend.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index bee8ef197917..5b906585cc13 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -843,7 +843,6 @@ def forward_decode( device="cuda", ).fill_(-1) - reduce_indptr = torch.empty( [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" ) @@ -856,14 +855,6 @@ def forward_decode( page_size = 1 nhead_kv = 1 - mtp = 1 - - split_params = { - "kv_granularity": max(page_size, 16), - "max_seqlen_qo": self.forward_metadata.max_q_len, - "uni_seqlen_qo": mtp, - "fast_mode": 1, - } meta = get_mla_metadata_v1( self.forward_metadata.qo_indptr, @@ -891,6 +882,8 @@ def forward_decode( kv_scale = torch.ones([1], dtype=torch.float, device="cuda") else: q_input = q + q_scale = None + kv_scale = None mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), @@ -903,12 +896,12 @@ def forward_decode( self.forward_metadata.max_q_len, layer.scaling, layer.logit_cap, - work_meta_data=work_meta_data, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, + #work_meta_data=work_meta_data, + #work_indptr=work_indptr, + #work_info_set=work_info_set, + #reduce_indptr=reduce_indptr, + 
#reduce_final_map=reduce_final_map,
+            #reduce_partial_map=reduce_partial_map,
             q_scale=q_scale,
             kv_scale=kv_scale,
         )

From 06b78dc34ebea0e7d637f95ef100d7c57ded9f45 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 31 Oct 2025 06:53:12 +0000
Subject: [PATCH 05/21] Add new env variable to control whether the persistent
 mla decode kernel is used

---
 .../srt/layers/attention/aiter_backend.py     | 180 ++++++++++++------
 1 file changed, 126 insertions(+), 54 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 5b906585cc13..64b7ec0540b6 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -43,10 +43,15 @@
     scaled_fp8_quant,
 )
 
+from sglang.srt.utils import (
+    get_bool_env_var,
+)
+
 import math
 
 from aiter import get_mla_metadata_v1
 
+_use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST")
 
 class WrapperDispatch(Enum):
     SLIDING_WINDOW = auto()
@@ -61,6 +66,12 @@ class ForwardMetadata:
     kv_last_page_len: torch.Tensor
     max_q_len: int
     max_kv_len: Optional[int]
+    work_metadata: Optional[torch.Tensor] = None
+    work_info_set: Optional[torch.Tensor] = None
+    work_indptr: Optional[torch.Tensor] = None
+    reduce_indptr: Optional[torch.Tensor] = None
+    reduce_final_map: Optional[torch.Tensor] = None
+    reduce_partial_map: Optional[torch.Tensor] = None
 
 
 global_workspace_buffer = None
@@ -370,6 +381,51 @@ def init_cuda_graph_state(
             device=self.device,
         )
 
+        if self.use_mla and _use_mla_ps_kernel:
+            # for persistent mla_decode_fwd
+            gpu = torch.cuda.current_device()
+            device_properties = torch.cuda.get_device_properties(gpu)
+            cu_num = device_properties.multi_processor_count
+
+            nhead = self.num_head
+
+            max_seqlen_qo = (
+                1
+                if self.num_draft_tokens is None
+                else self.num_draft_tokens
+            )
+
+            max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128))
+
+            batch_size = max_bs
+
+            self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
+            self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda")
+            self.work_info_set = torch.empty(
+                [batch_size * max_qo_tiles_per_batch * cu_num, 8],
+                dtype=torch.int32,
+                device="cuda",
+            ).fill_(-1)
+
+            self.reduce_indptr = torch.empty(
+                [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda"
+            )
+            self.reduce_final_map = torch.empty(
+                [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda"
+            )
+            self.reduce_partial_map = torch.empty(
+                [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda"
+            )
+
+        else:
+            self.work_metadata = None
+            self.work_indptr = None
+            self.work_info_set = None
+
+            self.reduce_indptr = None
+            self.reduce_final_map = None
+            self.reduce_partial_map = None
+
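# (Editorial note.) The persistent mla_decode_fwd path this patch adds is
# opt-in via the environment variable read once at module import above:
#
#     SGLANG_AITER_MLA_PERSIST=1 python -m sglang.launch_server ...
#
# (launch command illustrative; any entry point that imports this backend
# works). Because get_bool_env_var() runs at import time, the toggle cannot be
# flipped per request, and the scheduler buffers above are allocated up front
# only when it is set.

     def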
init_forward_metadata_capture_cuda_graph( kv_last_page_len, max_q_len, None, + work_metadata=self.work_metadata, + work_info_set=self.work_info_set, + work_indptr=self.work_indptr, + reduce_indptr=self.reduce_indptr, + reduce_final_map=self.reduce_final_map, + reduce_partial_map=self.reduce_partial_map, ) elif forward_mode.is_target_verify(): @@ -535,6 +620,35 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + if self.use_mla and _use_mla_ps_kernel: + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + + max_q_len = 1 + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) + elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -822,57 +936,15 @@ def forward_decode( ) if self.use_mla: - #k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(torch.bfloat16) k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - - nhead = layer.tp_q_head_num - max_qo_tiles_per_batch = int(math.ceil(self.forward_metadata.max_q_len * nhead / 128)) - - - batch_size = forward_batch.batch_size + work_meta_data = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set - work_meta_data = torch.empty([10], dtype=torch.uint64, device="cuda") - work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) - - page_size = 1 - nhead_kv = 1 - - meta = get_mla_metadata_v1( - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - nhead // nhead_kv, - nhead_kv, - True, - work_meta_data, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=self.forward_metadata.max_q_len, - uni_seqlen_qo=self.forward_metadata.max_q_len, - fast_mode=True, - ) + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map if self.kv_cache_dtype == fp8_dtype: q_input, q_scale = scaled_fp8_quant( @@ -896,12 +968,12 @@ def forward_decode( self.forward_metadata.max_q_len, layer.scaling, layer.logit_cap, - #work_meta_data=work_meta_data, - #work_indptr=work_indptr, - #work_info_set=work_info_set, - #reduce_indptr=reduce_indptr, - #reduce_final_map=reduce_final_map, - 
#reduce_partial_map=reduce_partial_map, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, q_scale=q_scale, kv_scale=kv_scale, ) From ac6ee4c4383428c68f62eb0e39462f105d09ce28 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Nov 2025 07:36:13 +0000 Subject: [PATCH 06/21] MTP fp8-kv accuracy pass --- .../srt/layers/attention/aiter_backend.py | 507 +++++++++++++++++- 1 file changed, 502 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 64b7ec0540b6..2d346b1b9bc6 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -184,6 +184,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_last_page_len = None max_q_len = None + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + if forward_batch.forward_mode.is_decode_or_idle(): if spec_info is None: kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) @@ -210,6 +217,58 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_last_page_len = self.kv_last_page_len[:bs] max_q_len = 1 + if _use_mla_ps_kernel: + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = self.num_head + + max_seqlen_qo = 1 + + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) + + batch_size = bs + + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -217,6 +276,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_last_page_len, max_q_len, None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) elif forward_batch.forward_mode.is_draft_extend(): @@ -229,6 +294,58 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.req_to_token, ) ) + + if _use_mla_ps_kernel: + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = self.num_head + + max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu) + + max_qo_tiles_per_batch = 
int(math.ceil(max_seqlen_qo * nhead / 128)) + + batch_size = bs + + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -237,6 +354,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.kv_last_page_len[:bs], max(forward_batch.extend_seq_lens_cpu), forward_batch.seq_lens_cpu.max().item(), + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) else: self.indices_updater_prefill.update( @@ -286,6 +409,58 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_indices, self.req_to_token.stride(0), ) + + if _use_mla_ps_kernel: + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = self.num_head + + max_seqlen_qo = draft_num + + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) + + batch_size = bs + + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -294,6 +469,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.kv_last_page_len[:bs], draft_num, None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) else: self.indices_updater_prefill.update( @@ -529,6 +710,28 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len = 
self.cuda_graph_kv_last_page_len[:bs] max_q_len = self.num_draft_tokens + if _use_mla_ps_kernel: + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -536,6 +739,12 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len, max_q_len, None, + work_metadata=self.work_metadata, + work_info_set=self.work_info_set, + work_indptr=self.work_indptr, + reduce_indptr=self.reduce_indptr, + reduce_final_map=self.reduce_final_map, + reduce_partial_map=self.reduce_partial_map, ) else: seq_lens_sum = seq_lens.sum().item() @@ -579,6 +788,28 @@ def init_forward_metadata_capture_cuda_graph( ) kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] max_q_len = num_tokens_per_bs + + if _use_mla_ps_kernel: + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -586,6 +817,12 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len, max_q_len, None, + work_metadata=self.work_metadata, + work_info_set=self.work_info_set, + work_indptr=self.work_indptr, + reduce_indptr=self.reduce_indptr, + reduce_final_map=self.reduce_final_map, + reduce_partial_map=self.reduce_partial_map, ) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -649,6 +886,7 @@ def init_forward_metadata_replay_cuda_graph( fast_mode=True, ) + elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -672,6 +910,30 @@ def init_forward_metadata_replay_cuda_graph( kv_indices, self.req_to_token.stride(0), ) + + if self.use_mla and _use_mla_ps_kernel: + max_q_len = self.num_draft_tokens + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) elif forward_mode.is_draft_extend(): seq_lens = seq_lens[:bs] accept_lens = spec_info.accept_length[:bs] @@ -689,6 +951,30 @@ def init_forward_metadata_replay_cuda_graph( kv_indices, self.req_to_token.stride(0), ) + + if self.use_mla and _use_mla_ps_kernel: + max_q_len = torch.max(accept_lens).item() + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) else: raise ValueError("Invalid forward mode") @@ -831,9 +1117,82 @@ def forward_extend( return o elif 
forward_batch.forward_mode.is_target_verify(): o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = self.num_head + + max_seqlen_qo = ( + 1 + if self.num_draft_tokens is None + else self.num_draft_tokens + ) + + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) + + batch_size = forward_batch.batch_size + + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + if self.kv_cache_dtype == fp8_dtype: + #q_input, q_scale = scaled_fp8_quant( + # q.view(-1, layer.tp_q_head_num*layer.qk_head_dim), + #) + #q_scale = q_scale.to(torch.float) + + q_input = q.to(fp8_dtype) + q_scale = torch.ones([1], dtype=torch.float, device="cuda") + kv_scale = torch.ones([1], dtype=torch.float, device="cuda") + else: + q_input = q + q_scale = None + kv_scale = None + + #q_input = q + #q_scale = None + #k_scale = None + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=self.forward_metadata.max_q_len, + uni_seqlen_qo=self.forward_metadata.max_q_len, + fast_mode=True, + ) + mla_decode_fwd( - q, - K_Buffer.view(-1, 1, 1, layer.qk_head_dim).to(torch.bfloat16), + q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + K_Buffer.view(-1, 1, 1, layer.qk_head_dim), o, self.forward_metadata.qo_indptr, self.forward_metadata.kv_indptr, @@ -842,18 +1201,96 @@ def forward_extend( self.forward_metadata.max_q_len, layer.scaling, layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=q_scale, + kv_scale=kv_scale, ) K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) return o elif forward_batch.forward_mode.is_draft_extend(): + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = self.num_head + + max_seqlen_qo = ( + 1 + if self.num_draft_tokens is None + else self.num_draft_tokens + ) + + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) + + batch_size = forward_batch.batch_size + + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * 
max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + if self.kv_cache_dtype == fp8_dtype: + #q_input, q_scale = scaled_fp8_quant( + # q.view(-1, layer.tp_q_head_num*layer.qk_head_dim), + #) + #q_scale = q_scale.to(torch.float) + q_input = q.to(fp8_dtype) + q_scale = torch.ones([1], dtype=torch.float, device="cuda") + + kv_scale = torch.ones([1], dtype=torch.float, device="cuda") + else: + q_input = q + q_scale = None + kv_scale = None + + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=self.forward_metadata.max_q_len, + uni_seqlen_qo=self.forward_metadata.max_q_len, + fast_mode=True, + ) + o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) causal = True sliding_window_size = -1 kv_indptr = self.forward_metadata.kv_indptr kv_indices = self.forward_metadata.kv_indices - mla_prefill_fwd( - q, - K_Buffer.view(-1, 1, 1, layer.qk_head_dim).to(torch.bfloat16), + mla_decode_fwd( + q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + K_Buffer.view(-1, 1, 1, layer.qk_head_dim), o, self.forward_metadata.qo_indptr, self.forward_metadata.kv_indptr, @@ -862,6 +1299,14 @@ def forward_extend( self.forward_metadata.max_q_len, layer.scaling, layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=q_scale, + kv_scale=kv_scale, ) K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) return o @@ -946,6 +1391,55 @@ def forward_decode( reduce_final_map = self.forward_metadata.reduce_final_map reduce_partial_map = self.forward_metadata.reduce_partial_map + #gpu = torch.cuda.current_device() + #device_properties = torch.cuda.get_device_properties(gpu) + #cu_num = device_properties.multi_processor_count + + #nhead = layer.tp_q_head_num + #max_qo_tiles_per_batch = int(math.ceil(self.forward_metadata.max_q_len * nhead / 128)) + + + #batch_size = forward_batch.batch_size + + #work_meta_data = torch.empty([10], dtype=torch.uint64, device="cuda") + #work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + #work_info_set = torch.empty( + # [batch_size * max_qo_tiles_per_batch * cu_num, 8], + # dtype=torch.int32, + # device="cuda", + #).fill_(-1) + + #reduce_indptr = torch.empty( + # [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + #) + #reduce_final_map = torch.empty( + # [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + #) + #reduce_partial_map = torch.empty( + # [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + #) + + #page_size = 1 + #nhead_kv = 1 + + #meta = get_mla_metadata_v1( + # self.forward_metadata.qo_indptr, + # self.forward_metadata.kv_indptr, + # nhead // nhead_kv, + # nhead_kv, + # True, + # work_meta_data, + # work_info_set, + # work_indptr, + # reduce_indptr, + # reduce_final_map, + # reduce_partial_map, + # kv_granularity=max(page_size, 16), + # max_seqlen_qo=self.forward_metadata.max_q_len, + # uni_seqlen_qo=self.forward_metadata.max_q_len, + # fast_mode=True, + #) + if self.kv_cache_dtype == fp8_dtype: q_input, q_scale = 
scaled_fp8_quant( q, @@ -957,6 +1451,9 @@ def forward_decode( q_scale = None kv_scale = None + #q_input = q + #q_scale = None + mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k_buffer.view(-1, 1, 1, layer.qk_head_dim), From 7be57a8f6367ecf63926cc1fdd3687eb77c1d8dc Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Nov 2025 08:18:26 +0000 Subject: [PATCH 07/21] Fix GPU fault when using persist mla_decode_fwd kernel in MTP --- .../srt/layers/attention/aiter_backend.py | 80 +------------------ 1 file changed, 2 insertions(+), 78 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 2d346b1b9bc6..1914b1383461 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -177,6 +177,7 @@ def __init__( def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" + bs = forward_batch.batch_size kv_indptr = self.kv_indptr spec_info = forward_batch.spec_info @@ -617,6 +618,7 @@ def init_forward_metadata_capture_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], ): + if forward_mode.is_decode_or_idle(): qo_indptr = None kv_last_page_len = None @@ -857,35 +859,6 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices - if self.use_mla and _use_mla_ps_kernel: - qo_indptr = self.qo_indptr_[: bs + 1] - qo_indptr[1 : bs + 1] = torch.cumsum( - self.cuda_graph_kv_last_page_len[:bs], dim=0 - ) - - max_q_len = 1 - - nhead_kv = 1 - page_size = 1 - - meta = get_mla_metadata_v1( - qo_indptr, - kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, - self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=True, - ) - elif forward_mode.is_target_verify(): bs = len(req_pool_indices) @@ -1391,55 +1364,6 @@ def forward_decode( reduce_final_map = self.forward_metadata.reduce_final_map reduce_partial_map = self.forward_metadata.reduce_partial_map - #gpu = torch.cuda.current_device() - #device_properties = torch.cuda.get_device_properties(gpu) - #cu_num = device_properties.multi_processor_count - - #nhead = layer.tp_q_head_num - #max_qo_tiles_per_batch = int(math.ceil(self.forward_metadata.max_q_len * nhead / 128)) - - - #batch_size = forward_batch.batch_size - - #work_meta_data = torch.empty([10], dtype=torch.uint64, device="cuda") - #work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - #work_info_set = torch.empty( - # [batch_size * max_qo_tiles_per_batch * cu_num, 8], - # dtype=torch.int32, - # device="cuda", - #).fill_(-1) - - #reduce_indptr = torch.empty( - # [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - #) - #reduce_final_map = torch.empty( - # [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - #) - #reduce_partial_map = torch.empty( - # [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - #) - - #page_size = 1 - #nhead_kv = 1 - - #meta = get_mla_metadata_v1( - # self.forward_metadata.qo_indptr, - # self.forward_metadata.kv_indptr, - # nhead // nhead_kv, - # nhead_kv, - # True, - # work_meta_data, - # work_info_set, - # work_indptr, - # 
reduce_indptr, - # reduce_final_map, - # reduce_partial_map, - # kv_granularity=max(page_size, 16), - # max_seqlen_qo=self.forward_metadata.max_q_len, - # uni_seqlen_qo=self.forward_metadata.max_q_len, - # fast_mode=True, - #) - if self.kv_cache_dtype == fp8_dtype: q_input, q_scale = scaled_fp8_quant( q, From f3864a958d074ddb9d856d5b428d842427b26c92 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Nov 2025 03:30:14 +0000 Subject: [PATCH 08/21] Code refactor --- .../srt/layers/attention/aiter_backend.py | 203 +++++------------- 1 file changed, 52 insertions(+), 151 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 1914b1383461..bc608657a77b 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -859,6 +859,34 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + #if self.use_mla and _use_mla_ps_kernel: + # qo_indptr = self.qo_indptr_[: bs + 1] + # qo_indptr[1 : bs + 1] = torch.cumsum( + # self.cuda_graph_kv_last_page_len[:bs], dim=0 + # ) + + # max_q_len = 1 + + # nhead_kv = 1 + # page_size = 1 + + # meta = get_mla_metadata_v1( + # qo_indptr, + # kv_indptr, + # self.num_head // nhead_kv, + # nhead_kv, + # True, + # self.work_metadata, + # self.work_info_set, + # self.work_indptr, + # self.reduce_indptr, + # self.reduce_final_map, + # self.reduce_partial_map, + # kv_granularity=max(page_size, 16), + # max_seqlen_qo=max_q_len, + # uni_seqlen_qo=max_q_len, + # fast_mode=True, + # ) elif forward_mode.is_target_verify(): bs = len(req_pool_indices) @@ -1090,46 +1118,10 @@ def forward_extend( return o elif forward_batch.forward_mode.is_target_verify(): o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - nhead = self.num_head - - max_seqlen_qo = ( - 1 - if self.num_draft_tokens is None - else self.num_draft_tokens - ) - - max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - - batch_size = forward_batch.batch_size - - work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") - work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) + if self.kv_cache_dtype == fp8_dtype: - #q_input, q_scale = scaled_fp8_quant( - # q.view(-1, layer.tp_q_head_num*layer.qk_head_dim), - #) - #q_scale = q_scale.to(torch.float) - q_input = q.to(fp8_dtype) q_scale = torch.ones([1], dtype=torch.float, device="cuda") kv_scale = torch.ones([1], dtype=torch.float, device="cuda") @@ -1138,30 +1130,13 @@ def forward_extend( q_scale = None kv_scale = None - #q_input = q - #q_scale = None - #k_scale = None - - nhead_kv = 1 - page_size = 1 + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = 
self.forward_metadata.work_info_set - meta = get_mla_metadata_v1( - qo_indptr, - kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, - work_metadata, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=self.forward_metadata.max_q_len, - uni_seqlen_qo=self.forward_metadata.max_q_len, - fast_mode=True, - ) + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), @@ -1183,48 +1158,12 @@ def forward_extend( q_scale=q_scale, kv_scale=kv_scale, ) - K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) + #K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) return o elif forward_batch.forward_mode.is_draft_extend(): - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - - nhead = self.num_head - - max_seqlen_qo = ( - 1 - if self.num_draft_tokens is None - else self.num_draft_tokens - ) - - max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - - batch_size = forward_batch.batch_size - - work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") - work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) + o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) if self.kv_cache_dtype == fp8_dtype: - #q_input, q_scale = scaled_fp8_quant( - # q.view(-1, layer.tp_q_head_num*layer.qk_head_dim), - #) - #q_scale = q_scale.to(torch.float) q_input = q.to(fp8_dtype) q_scale = torch.ones([1], dtype=torch.float, device="cuda") @@ -1234,33 +1173,14 @@ def forward_extend( q_scale = None kv_scale = None + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set - nhead_kv = 1 - page_size = 1 + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map - meta = get_mla_metadata_v1( - qo_indptr, - kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, - work_metadata, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=self.forward_metadata.max_q_len, - uni_seqlen_qo=self.forward_metadata.max_q_len, - fast_mode=True, - ) - - o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) - causal = True - sliding_window_size = -1 - kv_indptr = self.forward_metadata.kv_indptr - kv_indices = self.forward_metadata.kv_indices mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), K_Buffer.view(-1, 1, 1, layer.qk_head_dim), @@ -1281,27 +1201,8 @@ def forward_extend( q_scale=q_scale, kv_scale=kv_scale, ) - K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) + #K_Buffer = 
K_Buffer.view(-1, 1, layer.qk_head_dim) return o - # self.extend_attention_fwd( - # q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - # k.contiguous(), - # v.contiguous(), - # o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - # forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), - # forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), - # self.forward_metadata.qo_indptr, - # kv_indptr, - # kv_indices, - # None, - # causal, - # None, - # self.forward_metadata.max_q_len, - # layer.scaling, - # layer.logit_cap, - # sliding_window_size, - # ) - # return o else: raise ValueError( f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}" @@ -1365,19 +1266,19 @@ def forward_decode( reduce_partial_map = self.forward_metadata.reduce_partial_map if self.kv_cache_dtype == fp8_dtype: - q_input, q_scale = scaled_fp8_quant( - q, - ) - q_scale = q_scale.to(torch.float) + #q_input, q_scale = scaled_fp8_quant( + # q, + #) + #q_scale = q_scale.to(torch.float) + q_input = q.to(fp8_dtype) + q_scale = torch.ones([1], dtype=torch.float, device="cuda") kv_scale = torch.ones([1], dtype=torch.float, device="cuda") + else: q_input = q q_scale = None kv_scale = None - #q_input = q - #q_scale = None - mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k_buffer.view(-1, 1, 1, layer.qk_head_dim), @@ -1395,7 +1296,7 @@ def forward_decode( reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, - q_scale=q_scale, + q_scale=kv_scale, kv_scale=kv_scale, ) #k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim) From 5ce2aa1f47bc2e71ee9c09552b2f87211c6f3ddc Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Nov 2025 07:10:10 +0000 Subject: [PATCH 09/21] Refactor code v2 --- .../srt/layers/attention/aiter_backend.py | 250 ++++++++---------- 1 file changed, 108 insertions(+), 142 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index bc608657a77b..fd4f1f3798be 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -92,6 +92,8 @@ def __init__( extend_attention_fwd, ) + self.page_size = model_runner.server_args.page_size + self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd) self.device = model_runner.device @@ -174,6 +176,35 @@ def __init__( self.enable_dp_attention = is_dp_attention_enabled() + def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + + nhead = self.num_head + + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) + + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + work_info_set = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + + reduce_indptr = torch.empty( + [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" + ) + reduce_final_map = torch.empty( + [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + reduce_partial_map = torch.empty( + [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + ) + + return work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map + def 
init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -219,39 +250,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_q_len = 1 if _use_mla_ps_kernel: - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - - nhead = self.num_head - - max_seqlen_qo = 1 - - max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - - batch_size = bs - - work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") - work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) - nhead_kv = 1 page_size = 1 + max_seqlen_qo = 1 + work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -297,39 +301,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) if _use_mla_ps_kernel: - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - - nhead = self.num_head - - max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu) - - max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - - batch_size = bs - - work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") - work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) - nhead_kv = 1 page_size = 1 + max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu) + work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -347,6 +324,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): uni_seqlen_qo=max_q_len, fast_mode=True, ) + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -411,40 +389,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.req_to_token.stride(0), ) - if _use_mla_ps_kernel: - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - - nhead = self.num_head - - max_seqlen_qo = draft_num - - max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - - batch_size = bs - - work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") - work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, 
device="cuda") - work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) - + #if _use_mla_ps_kernel: + if self.kv_cache_dtype == fp8_dtype: nhead_kv = 1 page_size = 1 + max_seqlen_qo = draft_num + work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -458,10 +410,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_final_map, reduce_partial_map, kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, + max_seqlen_qo=max_seqlen_qo, + uni_seqlen_qo=max_seqlen_qo, fast_mode=True, ) + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -563,41 +516,16 @@ def init_cuda_graph_state( device=self.device, ) - if self.use_mla and _use_mla_ps_kernel: + #if self.use_mla and _use_mla_ps_kernel: + if self.use_mla and (_use_mla_ps_kernel or self.kv_cache_dtype == fp8_dtype): # for persistent mla_decode_fwd - gpu = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(gpu) - cu_num = device_properties.multi_processor_count - - nhead = self.num_head - max_seqlen_qo = ( 1 if self.num_draft_tokens is None else self.num_draft_tokens ) - max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - - batch_size = max_bs - - self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") - self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") - self.work_info_set = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num, 8], - dtype=torch.int32, - device="cuda", - ).fill_(-1) - - self.reduce_indptr = torch.empty( - [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" - ) - self.reduce_final_map = torch.empty( - [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" - ) - self.reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" - ) + self.work_metadata, self.work_indptr, self.work_info_set, self.reduce_indptr, self.reduce_final_map, self.reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs) else: self.work_metadata = None @@ -619,6 +547,14 @@ def init_forward_metadata_capture_cuda_graph( spec_info: Optional[SpecInput], ): + work_metadata = None + work_info_set = None + work_indptr = None + + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + if forward_mode.is_decode_or_idle(): qo_indptr = None kv_last_page_len = None @@ -672,6 +608,15 @@ def init_forward_metadata_capture_cuda_graph( fast_mode=True, ) + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -679,12 +624,12 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len, max_q_len, None, - 
work_metadata=self.work_metadata, - work_info_set=self.work_info_set, - work_indptr=self.work_indptr, - reduce_indptr=self.reduce_indptr, - reduce_final_map=self.reduce_final_map, - reduce_partial_map=self.reduce_partial_map, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) elif forward_mode.is_target_verify(): @@ -712,7 +657,8 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] max_q_len = self.num_draft_tokens - if _use_mla_ps_kernel: + #if _use_mla_ps_kernel: + if self.kv_cache_dtype == fp8_dtype: nhead_kv = 1 page_size = 1 @@ -734,6 +680,14 @@ def init_forward_metadata_capture_cuda_graph( fast_mode=True, ) + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -741,13 +695,14 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len, max_q_len, None, - work_metadata=self.work_metadata, - work_info_set=self.work_info_set, - work_indptr=self.work_indptr, - reduce_indptr=self.reduce_indptr, - reduce_final_map=self.reduce_final_map, - reduce_partial_map=self.reduce_partial_map, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) + else: seq_lens_sum = seq_lens.sum().item() self.indices_updater_prefill.update( @@ -812,6 +767,15 @@ def init_forward_metadata_capture_cuda_graph( uni_seqlen_qo=max_q_len, fast_mode=True, ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -819,13 +783,14 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len, max_q_len, None, - work_metadata=self.work_metadata, - work_info_set=self.work_info_set, - work_indptr=self.work_indptr, - reduce_indptr=self.reduce_indptr, - reduce_final_map=self.reduce_final_map, - reduce_partial_map=self.reduce_partial_map, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) + else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -912,7 +877,8 @@ def init_forward_metadata_replay_cuda_graph( self.req_to_token.stride(0), ) - if self.use_mla and _use_mla_ps_kernel: + #if self.use_mla and _use_mla_ps_kernel: + if self.use_mla and self.kv_cache_dtype == fp8_dtype: max_q_len = self.num_draft_tokens nhead_kv = 1 @@ -1139,7 +1105,7 @@ def forward_extend( reduce_partial_map = self.forward_metadata.reduce_partial_map mla_decode_fwd( - q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + q_input, K_Buffer.view(-1, 1, 1, layer.qk_head_dim), o, self.forward_metadata.qo_indptr, From 7848b53715a5ba58ed0314c1900d5dd74002b575 Mon Sep 17 00:00:00 2001 From: wunhuang Date: Thu, 13 Nov 2025 06:22:54 +0000 Subject: [PATCH 10/21] Code refactor v3 --- .../srt/layers/attention/aiter_backend.py | 79 +++++++++---------- 1 file changed, 37 
insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index fd4f1f3798be..caf04a8e5473 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -208,7 +208,6 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" - bs = forward_batch.batch_size kv_indptr = self.kv_indptr spec_info = forward_batch.spec_info @@ -389,8 +388,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.req_to_token.stride(0), ) - #if _use_mla_ps_kernel: - if self.kv_cache_dtype == fp8_dtype: + #if self.kv_cache_dtype == fp8_dtype: + if _use_mla_ps_kernel: nhead_kv = 1 page_size = 1 @@ -516,8 +515,8 @@ def init_cuda_graph_state( device=self.device, ) - #if self.use_mla and _use_mla_ps_kernel: - if self.use_mla and (_use_mla_ps_kernel or self.kv_cache_dtype == fp8_dtype): + #if self.use_mla and (_use_mla_ps_kernel or self.kv_cache_dtype == fp8_dtype): + if self.use_mla and _use_mla_ps_kernel: # for persistent mla_decode_fwd max_seqlen_qo = ( 1 @@ -583,7 +582,6 @@ def init_forward_metadata_capture_cuda_graph( self.cuda_graph_kv_last_page_len[:bs], dim=0 ) kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] - max_q_len = 1 if _use_mla_ps_kernel: @@ -657,8 +655,8 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] max_q_len = self.num_draft_tokens - #if _use_mla_ps_kernel: - if self.kv_cache_dtype == fp8_dtype: + #if self.kv_cache_dtype == fp8_dtype: + if _use_mla_ps_kernel: nhead_kv = 1 page_size = 1 @@ -702,7 +700,6 @@ def init_forward_metadata_capture_cuda_graph( reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, ) - else: seq_lens_sum = seq_lens.sum().item() self.indices_updater_prefill.update( @@ -790,7 +787,6 @@ def init_forward_metadata_capture_cuda_graph( reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, ) - else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -824,34 +820,35 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices - #if self.use_mla and _use_mla_ps_kernel: - # qo_indptr = self.qo_indptr_[: bs + 1] - # qo_indptr[1 : bs + 1] = torch.cumsum( - # self.cuda_graph_kv_last_page_len[:bs], dim=0 - # ) - - # max_q_len = 1 - - # nhead_kv = 1 - # page_size = 1 - - # meta = get_mla_metadata_v1( - # qo_indptr, - # kv_indptr, - # self.num_head // nhead_kv, - # nhead_kv, - # True, - # self.work_metadata, - # self.work_info_set, - # self.work_indptr, - # self.reduce_indptr, - # self.reduce_final_map, - # self.reduce_partial_map, - # kv_granularity=max(page_size, 16), - # max_seqlen_qo=max_q_len, - # uni_seqlen_qo=max_q_len, - # fast_mode=True, - # ) + if self.use_mla and (_use_mla_ps_kernel and self.num_draft_tokens == None): + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + + max_q_len = 1 + + nhead_kv = 1 + page_size = 1 + + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + 
kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=True, + ) elif forward_mode.is_target_verify(): bs = len(req_pool_indices) @@ -877,8 +874,8 @@ def init_forward_metadata_replay_cuda_graph( self.req_to_token.stride(0), ) - #if self.use_mla and _use_mla_ps_kernel: - if self.use_mla and self.kv_cache_dtype == fp8_dtype: + #if self.use_mla and self.kv_cache_dtype == fp8_dtype: + if self.use_mla and _use_mla_ps_kernel: max_q_len = self.num_draft_tokens nhead_kv = 1 @@ -981,7 +978,6 @@ def forward_extend( kv_indptr = self.forward_metadata.kv_indptr kv_indices = self.forward_metadata.kv_indices qo_indptr = self.forward_metadata.qo_indptr - K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) kv_lora_rank = V_Buffer.shape[-1] @@ -1028,7 +1024,6 @@ def forward_extend( k_prefix, v_prefix = torch.split( kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 ) - k_prefix = torch.cat( [ k_prefix, From a8463cc60207fac0cedea4ec73911e54053e400c Mon Sep 17 00:00:00 2001 From: wunhuang Date: Mon, 17 Nov 2025 01:56:24 +0000 Subject: [PATCH 11/21] Format code --- .../srt/layers/attention/aiter_backend.py | 110 +++++++++++------- .../sglang/srt/model_executor/model_runner.py | 5 +- 2 files changed, 69 insertions(+), 46 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index caf04a8e5473..08f33264ece5 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -36,23 +36,17 @@ "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." ) -from sglang.srt.configs.model_config import AttentionArch - -from sglang.srt.layers.quantization.fp8_kernel import ( - fp8_dtype, - scaled_fp8_quant, -) - -from sglang.srt.utils import ( - get_bool_env_var, -) - import math from aiter import get_mla_metadata_v1 +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype +from sglang.srt.utils import get_bool_env_var + _use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST") + class WrapperDispatch(Enum): SLIDING_WINDOW = auto() CROSS_ATTENTION = auto() @@ -180,11 +174,11 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): gpu = torch.cuda.current_device() device_properties = torch.cuda.get_device_properties(gpu) cu_num = device_properties.multi_processor_count - + nhead = self.num_head - + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128)) - + work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") work_info_set = torch.empty( @@ -192,7 +186,7 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): dtype=torch.int32, device="cuda", ).fill_(-1) - + reduce_indptr = torch.empty( [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda" ) @@ -200,10 +194,19 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" ) reduce_partial_map = torch.empty( - [batch_size * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device="cuda" + [batch_size * max_qo_tiles_per_batch * cu_num], + dtype=torch.int32, + device="cuda", + ) + + return ( + work_metadata, + work_indptr, + work_info_set, + 
reduce_indptr, + reduce_final_map, + reduce_partial_map, ) - - return work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -253,7 +256,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_size = 1 max_seqlen_qo = 1 - work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + ( + work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) meta = get_mla_metadata_v1( qo_indptr, @@ -304,7 +314,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_size = 1 max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu) - work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + ( + work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) meta = get_mla_metadata_v1( qo_indptr, @@ -388,13 +405,20 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.req_to_token.stride(0), ) - #if self.kv_cache_dtype == fp8_dtype: + # if self.kv_cache_dtype == fp8_dtype: if _use_mla_ps_kernel: nhead_kv = 1 page_size = 1 max_seqlen_qo = draft_num - work_metadata, work_indptr, work_info_set, reduce_indptr, reduce_final_map, reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + ( + work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) meta = get_mla_metadata_v1( qo_indptr, @@ -515,16 +539,21 @@ def init_cuda_graph_state( device=self.device, ) - #if self.use_mla and (_use_mla_ps_kernel or self.kv_cache_dtype == fp8_dtype): + # if self.use_mla and (_use_mla_ps_kernel or self.kv_cache_dtype == fp8_dtype): if self.use_mla and _use_mla_ps_kernel: # for persistent mla_decode_fwd max_seqlen_qo = ( - 1 - if self.num_draft_tokens is None - else self.num_draft_tokens + 1 if self.num_draft_tokens is None else self.num_draft_tokens ) - self.work_metadata, self.work_indptr, self.work_info_set, self.reduce_indptr, self.reduce_final_map, self.reduce_partial_map = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs) + ( + self.work_metadata, + self.work_indptr, + self.work_info_set, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs) else: self.work_metadata = None @@ -609,12 +638,11 @@ def init_forward_metadata_capture_cuda_graph( work_metadata = self.work_metadata work_info_set = self.work_info_set work_indptr = self.work_indptr - + reduce_indptr = self.reduce_indptr reduce_final_map = self.reduce_final_map reduce_partial_map = self.reduce_partial_map - self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, @@ -655,7 +683,7 @@ def init_forward_metadata_capture_cuda_graph( kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] max_q_len = self.num_draft_tokens - #if self.kv_cache_dtype == fp8_dtype: + # if self.kv_cache_dtype == fp8_dtype: if _use_mla_ps_kernel: nhead_kv = 1 page_size = 1 @@ -681,7 +709,7 @@ def init_forward_metadata_capture_cuda_graph( work_metadata 
= self.work_metadata
             work_info_set = self.work_info_set
             work_indptr = self.work_indptr
-
+
             reduce_indptr = self.reduce_indptr
             reduce_final_map = self.reduce_final_map
             reduce_partial_map = self.reduce_partial_map
@@ -874,7 +902,7 @@ def init_forward_metadata_replay_cuda_graph(
                 self.req_to_token.stride(0),
             )

-            #if self.use_mla and self.kv_cache_dtype == fp8_dtype:
+            # if self.use_mla and self.kv_cache_dtype == fp8_dtype:
             if self.use_mla and _use_mla_ps_kernel:
                 max_q_len = self.num_draft_tokens

@@ -1013,7 +1041,7 @@ def forward_extend(
             )

             if self.kv_cache_dtype == fp8_dtype:
-                kvc = kvc.to(torch.bfloat16) 
+                kvc = kvc.to(torch.bfloat16)
                 k_pe = k_pe.to(torch.bfloat16)

             kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
@@ -1080,8 +1108,6 @@ def forward_extend(
         elif forward_batch.forward_mode.is_target_verify():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))

-
-
             if self.kv_cache_dtype == fp8_dtype:
                 q_input = q.to(fp8_dtype)
                 q_scale = torch.ones([1], dtype=torch.float, device="cuda")
@@ -1119,7 +1145,7 @@ def forward_extend(
                 q_scale=q_scale,
                 kv_scale=kv_scale,
             )
-            #K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+            # K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
             return o
         elif forward_batch.forward_mode.is_draft_extend():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
@@ -1162,7 +1188,7 @@ def forward_extend(
                 q_scale=q_scale,
                 kv_scale=kv_scale,
             )
-            #K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+            # K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
             return o
         else:
             raise ValueError(
@@ -1227,10 +1253,10 @@ def forward_decode(
             reduce_partial_map = self.forward_metadata.reduce_partial_map

             if self.kv_cache_dtype == fp8_dtype:
-                #q_input, q_scale = scaled_fp8_quant(
+                # q_input, q_scale = scaled_fp8_quant(
                 #     q,
-                #)
-                #q_scale = q_scale.to(torch.float)
+                # )
+                # q_scale = q_scale.to(torch.float)
                 q_input = q.to(fp8_dtype)
                 q_scale = torch.ones([1], dtype=torch.float, device="cuda")
                 kv_scale = torch.ones([1], dtype=torch.float, device="cuda")
@@ -1260,7 +1286,7 @@ def forward_decode(
                 q_scale=kv_scale,
                 kv_scale=kv_scale,
             )
-            #k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
+            # k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
         else:
             self.logits_soft_cap = layer.logit_cap
             paged_attention_ragged(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index cbf0c6bd4316..6460507b6328 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -86,6 +86,7 @@
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -162,10 +163,6 @@
     FlattenedTensorMetadata,
 )

-from sglang.srt.layers.quantization.fp8_kernel import (
-    fp8_dtype,
-)
-
 MLA_ATTENTION_BACKENDS = [
     "aiter",
     "flashinfer",

From 82294e978ad5de1c87ca797422e4133b4cc4d62e Mon Sep 17 00:00:00 2001
From: wunhuang
Date: Mon, 17 Nov 2025 07:52:41 +0000
Subject: [PATCH 12/21] Convert the KV cache dtype to match q's dtype for the
 downstream paged_attention_ragged and mha_batch_prefill_func computation
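
paged_attention_ragged and mha_batch_prefill_func operate on q and the KV
cache with a single compute dtype. With an FP8 KV cache the q tensor stays
in bf16/fp16, so the cache is now upcast to q's dtype right before the
kernel call instead of being hard-coded to bfloat16. A minimal standalone
sketch of the conversion (illustrative only: fp8_dtype stands in for the
platform FP8 format, match_kv_dtype_to_q is a hypothetical helper, and the
plain .to() assumes a KV scale of 1.0, matching the unit scales used in
this series):

    import torch

    fp8_dtype = torch.float8_e4m3fnuz  # assumed HIP-native FP8 format

    def match_kv_dtype_to_q(q, k_cache, v_cache):
        # Upcast the FP8 cache to q's compute dtype so the attention
        # kernel sees matching operand dtypes.
        if k_cache.dtype == fp8_dtype:
            k_cache = k_cache.to(q.dtype)
            v_cache = v_cache.to(q.dtype)
        return k_cache, v_cache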
---
 .../srt/layers/attention/aiter_backend.py    | 30 ++++++++++++++-----
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 08f33264ece5..633540d90d09 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -1041,8 +1041,10 @@ def forward_extend(
             )

             if self.kv_cache_dtype == fp8_dtype:
-                kvc = kvc.to(torch.bfloat16)
-                k_pe = k_pe.to(torch.bfloat16)
+                dtype = q.dtype
+
+                kvc = kvc.to(dtype)
+                k_pe = k_pe.to(dtype)

             kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
@@ -1201,6 +1203,11 @@ def forward_extend(

             bs0 = forward_batch.batch_size + 1

+            if self.kv_cache_dtype == fp8_dtype:
+                dtype = q.dtype
+                k_cache = k_cache.to(dtype)
+                v_cache = v_cache.to(dtype)
+
             o = mha_batch_prefill_func(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache,
@@ -1289,16 +1296,23 @@ def forward_decode(
             # k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
         else:
             self.logits_soft_cap = layer.logit_cap
+
+            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+
+            if self.kv_cache_dtype == fp8_dtype:
+                dtype = q.dtype
+
+                k_cache = k_cache.to(dtype)
+                v_cache = v_cache.to(dtype)
+
             paged_attention_ragged(
                 o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                 self.workspace_buffer,
                 q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-                    -1, 1, layer.tp_k_head_num, layer.qk_head_dim
-                ),
-                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-                    -1, 1, layer.tp_v_head_num, layer.v_head_dim
-                ),
+                k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim),
+                v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim),
                 self.scale,
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,

From e4d27b5923e826f6694b1df4ed533c2db581d5d0 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 21 Nov 2025 14:08:13 +0000
Subject: [PATCH 13/21] Fix the MTP accuracy issues when using the aiter
 persistent v2 mla_decode kernel
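
The persistent v2 mla_decode path produced wrong results with MTP. This
patch replaces the hand-rolled tile-count sizing of the scheduler metadata
buffers with the (size, dtype) pairs that get_mla_metadata_info_v1 reports
for the exact batch shape, and runs get_mla_metadata_v1 with
fast_mode=False, a fixed per-batch split count, and intra-batch mode on the
decode, draft-extend, and target-verify paths. The allocation pattern, as a
small sketch (assumes specs is the tuple of (size, dtype) pairs returned by
get_mla_metadata_info_v1, in the order listed below):

    import torch

    def alloc_metadata_buffers(specs):
        # One CUDA buffer per (size, dtype) spec, in the kernel's order:
        # work_meta_data, work_indptr, work_info_set, reduce_indptr,
        # reduce_final_map, reduce_partial_map.
        return tuple(
            torch.empty(size, dtype=dtype, device="cuda")
            for size, dtype in specs
        )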
---
 .../srt/layers/attention/aiter_backend.py    | 308 +++++++++++-------
 1 file changed, 195 insertions(+), 113 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 633540d90d09..7c808b2f0905 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -27,6 +27,7 @@
 try:
     from aiter import (
         flash_attn_varlen_func,
+        get_mla_metadata_info_v1,
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
@@ -36,7 +37,6 @@
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )

-import math

 from aiter import get_mla_metadata_v1

@@ -66,6 +66,8 @@ class ForwardMetadata:
     reduce_indptr: Optional[torch.Tensor] = None
     reduce_final_map: Optional[torch.Tensor] = None
     reduce_partial_map: Optional[torch.Tensor] = None
+    num_kv_splits: Optional[int] = None
+    num_kv_splits_indptr: Optional[torch.Tensor] = None


 global_workspace_buffer = None
@@ -171,32 +173,48 @@ def __init__(
         self.enable_dp_attention = is_dp_attention_enabled()

     def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size):
-        gpu = torch.cuda.current_device()
-        device_properties = torch.cuda.get_device_properties(gpu)
-        cu_num = device_properties.multi_processor_count
-
         nhead = self.num_head
+        dtype = self.kv_cache_dtype
+
+        (
+            (work_meta_data_size, work_meta_data_type),
+            (work_indptr_size, work_indptr_type),
+            (work_info_set_size, work_info_set_type),
+            (reduce_indptr_size, reduce_indptr_type),
+            (reduce_final_map_size, reduce_final_map_type),
+            (reduce_partial_map_size, reduce_partial_map_type),
+        ) = get_mla_metadata_info_v1(
+            batch_size,
+            max_seqlen_qo,
+            nhead,
+            dtype,
+            dtype,
+            is_sparse=False,
+            fast_mode=True,
+            intra_batch_mode=True,
+        )

-        max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * nhead / 128))
-
-        work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
-        work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda")
+        # aiter implementation
+        # For the meaning of each tensor, refer to aiter/ops/attention.py
+        work_metadata = torch.empty(
+            work_meta_data_size, dtype=work_meta_data_type, device="cuda"
+        )
+        work_indptr = torch.empty(
+            work_indptr_size, dtype=work_indptr_type, device="cuda"
+        )
         work_info_set = torch.empty(
-            [batch_size * max_qo_tiles_per_batch * cu_num, 8],
-            dtype=torch.int32,
+            work_info_set_size,
+            dtype=work_info_set_type,
             device="cuda",
-        ).fill_(-1)
-
+        )
         reduce_indptr = torch.empty(
-            [batch_size * max_qo_tiles_per_batch + 1], dtype=torch.int32, device="cuda"
+            reduce_indptr_size, dtype=reduce_indptr_type, device="cuda"
         )
         reduce_final_map = torch.empty(
-            [batch_size * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda"
+            reduce_final_map_size, dtype=reduce_final_map_type, device="cuda"
         )
         reduce_partial_map = torch.empty(
-            [batch_size * max_qo_tiles_per_batch * cu_num],
-            dtype=torch.int32,
-            device="cuda",
+            reduce_partial_map_size, dtype=reduce_partial_map_type, device="cuda"
         )

         return (
@@ -208,6 +226,15 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size):
             reduce_partial_map,
         )

+    def make_split_kv_buffer(self, bs):
+        num_kv_splits = 32
+
+        num_kv_splits_indptr = torch.arange(
+            0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda"
+        )
+
+        return num_kv_splits, num_kv_splits_indptr
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -225,6 +252,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         reduce_final_map = None
         reduce_partial_map = None

+        num_kv_splits = None
+        num_kv_splits_indptr = None
+
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
@@ -251,7 +281,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             kv_last_page_len = self.kv_last_page_len[:bs]
             max_q_len = 1

-            if _use_mla_ps_kernel:
+            if _use_mla_ps_kernel and self.num_draft_tokens == None:
                 nhead_kv = 1
                 page_size = 1

@@ -265,6 +295,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 reduce_partial_map,
             ) =
self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -280,7 +312,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_granularity=max(page_size, 16), max_seqlen_qo=max_q_len, uni_seqlen_qo=max_q_len, - fast_mode=True, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, ) self.forward_metadata = ForwardMetadata( @@ -296,6 +330,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) elif forward_batch.forward_mode.is_draft_extend(): @@ -323,6 +359,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_partial_map, ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -336,9 +374,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_final_map, reduce_partial_map, kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=True, + max_seqlen_qo=max_seqlen_qo, + uni_seqlen_qo=max_seqlen_qo, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, ) self.forward_metadata = ForwardMetadata( @@ -355,6 +395,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) else: self.indices_updater_prefill.update( @@ -420,6 +462,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_partial_map, ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -435,7 +479,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_granularity=max(page_size, 16), max_seqlen_qo=max_seqlen_qo, uni_seqlen_qo=max_seqlen_qo, - fast_mode=True, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, ) self.forward_metadata = ForwardMetadata( @@ -452,6 +498,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) else: self.indices_updater_prefill.update( @@ -575,6 +623,9 @@ def init_forward_metadata_capture_cuda_graph( spec_info: Optional[SpecInput], ): + num_kv_splits = None + num_kv_splits_indptr = None + work_metadata = None work_info_set = None work_indptr = None @@ -617,6 +668,8 @@ def init_forward_metadata_capture_cuda_graph( nhead_kv = 1 page_size = 1 + num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -632,7 +685,9 @@ def init_forward_metadata_capture_cuda_graph( kv_granularity=max(page_size, 16), max_seqlen_qo=max_q_len, uni_seqlen_qo=max_q_len, - fast_mode=True, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, ) work_metadata = self.work_metadata @@ -649,13 +704,15 @@ def init_forward_metadata_capture_cuda_graph( qo_indptr, kv_last_page_len, max_q_len, - None, 
+ kv_indptr[-1].item(), work_metadata=work_metadata, work_info_set=work_info_set, work_indptr=work_indptr, reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) elif forward_mode.is_target_verify(): @@ -688,6 +745,8 @@ def init_forward_metadata_capture_cuda_graph( nhead_kv = 1 page_size = 1 + num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -703,7 +762,9 @@ def init_forward_metadata_capture_cuda_graph( kv_granularity=max(page_size, 16), max_seqlen_qo=max_q_len, uni_seqlen_qo=max_q_len, - fast_mode=True, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, ) work_metadata = self.work_metadata @@ -720,13 +781,15 @@ def init_forward_metadata_capture_cuda_graph( qo_indptr, kv_last_page_len, max_q_len, - None, + kv_indptr[-1].item(), work_metadata=work_metadata, work_info_set=work_info_set, work_indptr=work_indptr, reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) else: seq_lens_sum = seq_lens.sum().item() @@ -775,6 +838,8 @@ def init_forward_metadata_capture_cuda_graph( nhead_kv = 1 page_size = 1 + num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + meta = get_mla_metadata_v1( qo_indptr, kv_indptr, @@ -790,7 +855,9 @@ def init_forward_metadata_capture_cuda_graph( kv_granularity=max(page_size, 16), max_seqlen_qo=max_q_len, uni_seqlen_qo=max_q_len, - fast_mode=True, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, ) work_metadata = self.work_metadata @@ -807,13 +874,15 @@ def init_forward_metadata_capture_cuda_graph( qo_indptr, kv_last_page_len, max_q_len, - None, + kv_indptr[-1].item(), work_metadata=work_metadata, work_info_set=work_info_set, work_indptr=work_indptr, reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -848,36 +917,6 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices - if self.use_mla and (_use_mla_ps_kernel and self.num_draft_tokens == None): - - qo_indptr = self.qo_indptr_[: bs + 1] - qo_indptr[1 : bs + 1] = torch.cumsum( - self.cuda_graph_kv_last_page_len[:bs], dim=0 - ) - - max_q_len = 1 - - nhead_kv = 1 - page_size = 1 - - meta = get_mla_metadata_v1( - qo_indptr, - kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, - self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=True, - ) - elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -902,30 +941,6 @@ def init_forward_metadata_replay_cuda_graph( self.req_to_token.stride(0), ) - # if self.use_mla and self.kv_cache_dtype == fp8_dtype: - if self.use_mla and _use_mla_ps_kernel: - max_q_len = self.num_draft_tokens - - nhead_kv = 1 - page_size = 1 - - meta = get_mla_metadata_v1( - qo_indptr, - kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, - 
self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=True, - ) elif forward_mode.is_draft_extend(): seq_lens = seq_lens[:bs] accept_lens = spec_info.accept_length[:bs] @@ -944,29 +959,6 @@ def init_forward_metadata_replay_cuda_graph( self.req_to_token.stride(0), ) - if self.use_mla and _use_mla_ps_kernel: - max_q_len = torch.max(accept_lens).item() - - nhead_kv = 1 - page_size = 1 - - meta = get_mla_metadata_v1( - qo_indptr, - kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, - self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=True, - ) else: raise ValueError("Invalid forward mode") @@ -1127,6 +1119,34 @@ def forward_extend( reduce_final_map = self.forward_metadata.reduce_final_map reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr + + max_q_len = self.forward_metadata.max_q_len + nhead_kv = 1 + page_size = 1 + + if layer.layer_id == 0 and _use_mla_ps_kernel: + meta = get_mla_metadata_v1( + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, + ) + mla_decode_fwd( q_input, K_Buffer.view(-1, 1, 1, layer.qk_head_dim), @@ -1144,10 +1164,12 @@ def forward_extend( reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, - q_scale=q_scale, + q_scale=kv_scale, kv_scale=kv_scale, + intra_batch_mode=True, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) - # K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) return o elif forward_batch.forward_mode.is_draft_extend(): o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim)) @@ -1170,6 +1192,34 @@ def forward_extend( reduce_final_map = self.forward_metadata.reduce_final_map reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr + + max_q_len = self.forward_metadata.max_q_len + nhead_kv = 1 + page_size = 1 + + if layer.layer_id == 0 and _use_mla_ps_kernel: + meta = get_mla_metadata_v1( + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, + ) + mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), K_Buffer.view(-1, 1, 1, layer.qk_head_dim), @@ -1187,10 +1237,12 @@ def forward_extend( reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, - q_scale=q_scale, + q_scale=kv_scale, 
kv_scale=kv_scale, + intra_batch_mode=True, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) - # K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim) return o else: raise ValueError( @@ -1251,7 +1303,7 @@ def forward_decode( if self.use_mla: k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - work_meta_data = self.forward_metadata.work_metadata + work_metadata = self.forward_metadata.work_metadata work_indptr = self.forward_metadata.work_indptr work_info_set = self.forward_metadata.work_info_set @@ -1259,6 +1311,9 @@ def forward_decode( reduce_final_map = self.forward_metadata.reduce_final_map reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr + if self.kv_cache_dtype == fp8_dtype: # q_input, q_scale = scaled_fp8_quant( # q, @@ -1273,6 +1328,30 @@ def forward_decode( q_scale = None kv_scale = None + nhead_kv = 1 + page_size = 1 + + if layer.layer_id == 0 and _use_mla_ps_kernel: + meta = get_mla_metadata_v1( + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=1, + uni_seqlen_qo=1, + fast_mode=False, + max_split_per_batch=num_kv_splits, + intera_batch_mode=True, + ) + mla_decode_fwd( q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k_buffer.view(-1, 1, 1, layer.qk_head_dim), @@ -1284,7 +1363,7 @@ def forward_decode( self.forward_metadata.max_q_len, layer.scaling, layer.logit_cap, - work_meta_data=work_meta_data, + work_meta_data=work_metadata, work_indptr=work_indptr, work_info_set=work_info_set, reduce_indptr=reduce_indptr, @@ -1292,6 +1371,9 @@ def forward_decode( reduce_partial_map=reduce_partial_map, q_scale=kv_scale, kv_scale=kv_scale, + intra_batch_mode=True, + num_kv_splits=num_kv_splits, + num_kv_splits_indptr=num_kv_splits_indptr, ) # k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim) else: From 2d682c1edf640b94bb874534c918724184c9c842 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Nov 2025 08:33:58 +0000 Subject: [PATCH 14/21] Change the code for adaptive kv-split mla_decode_forward version --- .../srt/layers/attention/aiter_backend.py | 289 ++++++++---------- 1 file changed, 126 insertions(+), 163 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 7c808b2f0905..5b6fee6002a8 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -47,6 +47,15 @@ _use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST") +# Persist +# fast_mode=True if _use_mla_ps_kernel else False +# intra_batch_mode=False if _use_mla_ps_kernel else True + +# fake non-ps +fast_mode = False if _use_mla_ps_kernel else False +intra_batch_mode = True if _use_mla_ps_kernel else False + + class WrapperDispatch(Enum): SLIDING_WINDOW = auto() CROSS_ATTENTION = auto() @@ -67,7 +76,7 @@ class ForwardMetadata: reduce_final_map: Optional[torch.Tensor] = None reduce_partial_map: Optional[torch.Tensor] = None num_kv_splits: Optional[int] = None - num_kv_splits_indptr: Optional[torch.Tensor] = None + # num_kv_splits_indptr: Optional[torch.Tensor] = None global_workspace_buffer = None @@ -172,6 +181,15 @@ def __init__( self.enable_dp_attention = 
is_dp_attention_enabled() + self.max_split_per_batch = 32 if _use_mla_ps_kernel else None + + if self.num_draft_tokens is None and _use_mla_ps_kernel: + self.max_split_per_batch = 64 + + self.kv_scale = None + if self.kv_cache_dtype == fp8_dtype: + self.kv_scale = torch.ones([1], dtype=torch.float, device="cuda") + def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): nhead = self.num_head dtype = self.kv_cache_dtype @@ -190,8 +208,9 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): dtype, dtype, is_sparse=False, - fast_mode=True, - intra_batch_mode=True, + fast_mode=fast_mode, + num_kv_splits=self.max_split_per_batch, + intra_batch_mode=intra_batch_mode, ) # aiter implementation @@ -226,15 +245,44 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size): reduce_partial_map, ) - def make_split_kv_buffer(self, bs): - num_kv_splits = 32 + def make_mla_meta_data( + self, + qo_indptr, + kv_indptr, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + max_q_len, + fast_mode, + max_split_per_batch, + intra_batch_mode, + ): - num_kv_splits_indptr = torch.arange( - 0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda" + nhead_kv = 1 + page_size = 1 + meta = get_mla_metadata_v1( + qo_indptr, + kv_indptr, + self.num_head // nhead_kv, + nhead_kv, + True, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=max_q_len, + uni_seqlen_qo=max_q_len, + fast_mode=fast_mode, + max_split_per_batch=max_split_per_batch, + intera_batch_mode=intra_batch_mode, ) - return num_kv_splits, num_kv_splits_indptr - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -253,7 +301,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_partial_map = None num_kv_splits = None - num_kv_splits_indptr = None + # num_kv_splits_indptr = None if forward_batch.forward_mode.is_decode_or_idle(): if spec_info is None: @@ -281,11 +329,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_last_page_len = self.kv_last_page_len[:bs] max_q_len = 1 - if _use_mla_ps_kernel and self.num_draft_tokens == None: - nhead_kv = 1 - page_size = 1 + if _use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch - max_seqlen_qo = 1 ( work_metadata, work_indptr, @@ -293,28 +339,21 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_indptr, reduce_final_map, reduce_partial_map, - ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) + ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs) - num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) - - meta = get_mla_metadata_v1( + self.make_mla_meta_data( qo_indptr, kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, work_metadata, work_info_set, work_indptr, reduce_indptr, reduce_final_map, reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, + max_q_len, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) self.forward_metadata = ForwardMetadata( @@ -331,7 +370,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, num_kv_splits=num_kv_splits, - 
num_kv_splits_indptr=num_kv_splits_indptr, ) elif forward_batch.forward_mode.is_draft_extend(): @@ -346,8 +384,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) if _use_mla_ps_kernel: - nhead_kv = 1 - page_size = 1 + + num_kv_splits = self.max_split_per_batch max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu) ( @@ -359,26 +397,19 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_partial_map, ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) - num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) - - meta = get_mla_metadata_v1( + self.make_mla_meta_data( qo_indptr, kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, work_metadata, work_info_set, work_indptr, reduce_indptr, reduce_final_map, reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_seqlen_qo, - uni_seqlen_qo=max_seqlen_qo, - fast_mode=False, + max_seqlen_qo, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) self.forward_metadata = ForwardMetadata( @@ -396,7 +427,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) else: self.indices_updater_prefill.update( @@ -449,8 +480,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # if self.kv_cache_dtype == fp8_dtype: if _use_mla_ps_kernel: - nhead_kv = 1 - page_size = 1 + + num_kv_splits = self.max_split_per_batch max_seqlen_qo = draft_num ( @@ -462,26 +493,19 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_partial_map, ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs) - num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) - - meta = get_mla_metadata_v1( + self.make_mla_meta_data( qo_indptr, kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, work_metadata, work_info_set, work_indptr, reduce_indptr, reduce_final_map, reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_seqlen_qo, - uni_seqlen_qo=max_seqlen_qo, - fast_mode=False, + max_seqlen_qo, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) self.forward_metadata = ForwardMetadata( @@ -499,7 +523,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) else: self.indices_updater_prefill.update( @@ -624,7 +648,7 @@ def init_forward_metadata_capture_cuda_graph( ): num_kv_splits = None - num_kv_splits_indptr = None + # num_kv_splits_indptr = None work_metadata = None work_info_set = None @@ -665,29 +689,21 @@ def init_forward_metadata_capture_cuda_graph( max_q_len = 1 if _use_mla_ps_kernel: - nhead_kv = 1 - page_size = 1 - - num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + num_kv_splits = self.max_split_per_batch - meta = get_mla_metadata_v1( + self.make_mla_meta_data( qo_indptr, kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, self.work_metadata, self.work_info_set, self.work_indptr, self.reduce_indptr, self.reduce_final_map, self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - 
fast_mode=False, + max_q_len, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) work_metadata = self.work_metadata @@ -712,7 +728,7 @@ def init_forward_metadata_capture_cuda_graph( reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) elif forward_mode.is_target_verify(): @@ -742,29 +758,22 @@ def init_forward_metadata_capture_cuda_graph( # if self.kv_cache_dtype == fp8_dtype: if _use_mla_ps_kernel: - nhead_kv = 1 - page_size = 1 - num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + num_kv_splits = self.max_split_per_batch - meta = get_mla_metadata_v1( + self.make_mla_meta_data( qo_indptr, kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, self.work_metadata, self.work_info_set, self.work_indptr, self.reduce_indptr, self.reduce_final_map, self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, + max_q_len, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) work_metadata = self.work_metadata @@ -789,7 +798,7 @@ def init_forward_metadata_capture_cuda_graph( reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) else: seq_lens_sum = seq_lens.sum().item() @@ -835,29 +844,22 @@ def init_forward_metadata_capture_cuda_graph( max_q_len = num_tokens_per_bs if _use_mla_ps_kernel: - nhead_kv = 1 - page_size = 1 - num_kv_splits, num_kv_splits_indptr = self.make_split_kv_buffer(bs) + num_kv_splits = self.max_split_per_batch - meta = get_mla_metadata_v1( + self.make_mla_meta_data( qo_indptr, kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, self.work_metadata, self.work_info_set, self.work_indptr, self.reduce_indptr, self.reduce_final_map, self.reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, + max_q_len, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) work_metadata = self.work_metadata @@ -882,7 +884,7 @@ def init_forward_metadata_capture_cuda_graph( reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -1104,12 +1106,9 @@ def forward_extend( if self.kv_cache_dtype == fp8_dtype: q_input = q.to(fp8_dtype) - q_scale = torch.ones([1], dtype=torch.float, device="cuda") - kv_scale = torch.ones([1], dtype=torch.float, device="cuda") + self.kv_scale.fill_(1) else: q_input = q - q_scale = None - kv_scale = None work_metadata = self.forward_metadata.work_metadata work_indptr = self.forward_metadata.work_indptr @@ -1120,31 +1119,22 @@ def forward_extend( reduce_partial_map = self.forward_metadata.reduce_partial_map num_kv_splits = self.forward_metadata.num_kv_splits - num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr - - max_q_len = self.forward_metadata.max_q_len - nhead_kv = 1 - page_size = 1 + # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr if layer.layer_id == 0 and _use_mla_ps_kernel: 
- meta = get_mla_metadata_v1( + self.make_mla_meta_data( self.forward_metadata.qo_indptr, self.forward_metadata.kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, work_metadata, work_info_set, work_indptr, reduce_indptr, reduce_final_map, reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, + self.forward_metadata.max_q_len, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) mla_decode_fwd( @@ -1164,11 +1154,11 @@ def forward_extend( reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, - q_scale=kv_scale, - kv_scale=kv_scale, - intra_batch_mode=True, + q_scale=self.kv_scale, + kv_scale=self.kv_scale, + intra_batch_mode=intra_batch_mode, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) return o elif forward_batch.forward_mode.is_draft_extend(): @@ -1176,13 +1166,8 @@ def forward_extend( if self.kv_cache_dtype == fp8_dtype: q_input = q.to(fp8_dtype) - q_scale = torch.ones([1], dtype=torch.float, device="cuda") - - kv_scale = torch.ones([1], dtype=torch.float, device="cuda") else: q_input = q - q_scale = None - kv_scale = None work_metadata = self.forward_metadata.work_metadata work_indptr = self.forward_metadata.work_indptr @@ -1193,31 +1178,22 @@ def forward_extend( reduce_partial_map = self.forward_metadata.reduce_partial_map num_kv_splits = self.forward_metadata.num_kv_splits - num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr - - max_q_len = self.forward_metadata.max_q_len - nhead_kv = 1 - page_size = 1 + # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr if layer.layer_id == 0 and _use_mla_ps_kernel: - meta = get_mla_metadata_v1( + self.make_mla_meta_data( self.forward_metadata.qo_indptr, self.forward_metadata.kv_indptr, - self.num_head // nhead_kv, - nhead_kv, - True, work_metadata, work_info_set, work_indptr, reduce_indptr, reduce_final_map, reduce_partial_map, - kv_granularity=max(page_size, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, + self.forward_metadata.max_q_len, + fast_mode=fast_mode, max_split_per_batch=num_kv_splits, - intera_batch_mode=True, + intra_batch_mode=intra_batch_mode, ) mla_decode_fwd( @@ -1237,11 +1213,11 @@ def forward_extend( reduce_indptr=reduce_indptr, reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, - q_scale=kv_scale, - kv_scale=kv_scale, - intra_batch_mode=True, + q_scale=self.kv_scale, + kv_scale=self.kv_scale, + intra_batch_mode=intra_batch_mode, num_kv_splits=num_kv_splits, - num_kv_splits_indptr=num_kv_splits_indptr, + # num_kv_splits_indptr=num_kv_splits_indptr, ) return o else: @@ -1312,7 +1288,7 @@ def forward_decode( reduce_partial_map = self.forward_metadata.reduce_partial_map num_kv_splits = self.forward_metadata.num_kv_splits - num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr + # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr if self.kv_cache_dtype == fp8_dtype: # q_input, q_scale = scaled_fp8_quant( @@ -1320,36 +1296,23 @@ def forward_decode( # ) # q_scale = q_scale.to(torch.float) q_input = q.to(fp8_dtype) - q_scale = torch.ones([1], dtype=torch.float, device="cuda") - kv_scale = torch.ones([1], dtype=torch.float, device="cuda") - else: q_input = q - q_scale = None - kv_scale = None - - nhead_kv = 1 - page_size = 1 if layer.layer_id == 0 and 
_use_mla_ps_kernel:
-            meta = get_mla_metadata_v1(
+            self.make_mla_meta_data(
                 self.forward_metadata.qo_indptr,
                 self.forward_metadata.kv_indptr,
-                self.num_head // nhead_kv,
-                nhead_kv,
-                True,
                 work_metadata,
                 work_info_set,
                 work_indptr,
                 reduce_indptr,
                 reduce_final_map,
                 reduce_partial_map,
-                kv_granularity=max(page_size, 16),
-                max_seqlen_qo=1,
-                uni_seqlen_qo=1,
-                fast_mode=False,
+                self.forward_metadata.max_q_len,
+                fast_mode=fast_mode,
                 max_split_per_batch=num_kv_splits,
-                intera_batch_mode=True,
+                intra_batch_mode=intra_batch_mode,
             )
 
         mla_decode_fwd(
@@ -1369,11 +1332,11 @@ def forward_decode(
             reduce_indptr=reduce_indptr,
             reduce_final_map=reduce_final_map,
             reduce_partial_map=reduce_partial_map,
-            q_scale=kv_scale,
-            kv_scale=kv_scale,
-            intra_batch_mode=True,
+            q_scale=self.kv_scale,
+            kv_scale=self.kv_scale,
+            intra_batch_mode=intra_batch_mode,
             num_kv_splits=num_kv_splits,
-            num_kv_splits_indptr=num_kv_splits_indptr,
+            # num_kv_splits_indptr=num_kv_splits_indptr,
         )
         # k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
     else:

From fe0f700667f6d7b411371078b628e20aa8363025 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 1 Dec 2025 03:54:17 +0000
Subject: [PATCH 15/21] Fuse q quantization and the kv-cache store into
 qk_rope_cat

---
 .../srt/layers/attention/aiter_backend.py     | 74 +++++++++----------
 .../srt/layers/quantization/quark/quark.py    |  2 +
 python/sglang/srt/layers/rocm_linear_utils.py |  3 +-
 python/sglang/srt/models/deepseek_v2.py       | 21 +++++-
 4 files changed, 58 insertions(+), 42 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 5b6fee6002a8..3909f37eb186 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -4,6 +4,7 @@
 end to end attention solution with aiter kernels
 """
 
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Optional
@@ -44,8 +45,9 @@
 from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.utils import get_bool_env_var
 
-_use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST")
+logger = logging.getLogger(__name__)
 
+_use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST")
 # Persist
 # fast_mode=True if _use_mla_ps_kernel else False
@@ -97,6 +99,8 @@ def __init__(
             extend_attention_fwd,
         )
 
+        self.input_dtype = model_runner.model_config.dtype
+
         self.page_size = model_runner.server_args.page_size
         self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
@@ -186,10 +190,6 @@ def __init__(
         if self.num_draft_tokens is None and _use_mla_ps_kernel:
             self.max_split_per_batch = 64
 
-        self.kv_scale = None
-        if self.kv_cache_dtype == fp8_dtype:
-            self.kv_scale = torch.ones([1], dtype=torch.float, device="cuda")
-
    def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size):
         nhead = self.num_head
         dtype = self.kv_cache_dtype
@@ -658,6 +658,8 @@ def init_forward_metadata_capture_cuda_graph(
         reduce_final_map = None
         reduce_partial_map = None
 
+        # log_info_on_rank0(logger, f"[init_forward_metadata_capture_cuda_graph] {forward_mode=}")
+
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
@@ -900,6 +902,9 @@ def init_forward_metadata_replay_cuda_graph(
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
+        # log_info_on_rank0(logger, f"[init_forward_metadata_replay_cuda_graph] {forward_mode=}")
+
         if forward_mode.is_decode_or_idle():
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
@@ -1102,13 +1107,10 @@ def forward_extend(
             K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
             return o
         elif forward_batch.forward_mode.is_target_verify():
-            o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
-
-            if self.kv_cache_dtype == fp8_dtype:
-                q_input = q.to(fp8_dtype)
-                self.kv_scale.fill_(1)
-            else:
-                q_input = q
+            o = q.new_empty(
+                (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),
+                dtype=self.input_dtype,
+            )
 
             work_metadata = self.forward_metadata.work_metadata
             work_indptr = self.forward_metadata.work_indptr
@@ -1138,7 +1140,7 @@ def forward_extend(
                 )
 
             mla_decode_fwd(
-                q_input,
+                q,
                 K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
                 o,
                 self.forward_metadata.qo_indptr,
@@ -1154,20 +1156,18 @@ def forward_extend(
                 reduce_indptr=reduce_indptr,
                 reduce_final_map=reduce_final_map,
                 reduce_partial_map=reduce_partial_map,
-                q_scale=self.kv_scale,
-                kv_scale=self.kv_scale,
+                q_scale=layer.k_scale,
+                kv_scale=layer.k_scale,
                 intra_batch_mode=intra_batch_mode,
                 num_kv_splits=num_kv_splits,
                 # num_kv_splits_indptr=num_kv_splits_indptr,
             )
             return o
         elif forward_batch.forward_mode.is_draft_extend():
-            o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
-
-            if self.kv_cache_dtype == fp8_dtype:
-                q_input = q.to(fp8_dtype)
-            else:
-                q_input = q
+            o = q.new_empty(
+                (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),
+                dtype=self.input_dtype,
+            )
 
             work_metadata = self.forward_metadata.work_metadata
             work_indptr = self.forward_metadata.work_indptr
@@ -1197,7 +1197,7 @@ def forward_extend(
                 )
 
             mla_decode_fwd(
-                q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                 K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
                 o,
                 self.forward_metadata.qo_indptr,
@@ -1213,8 +1213,8 @@ def forward_extend(
                 reduce_indptr=reduce_indptr,
                 reduce_final_map=reduce_final_map,
                 reduce_partial_map=reduce_partial_map,
-                q_scale=self.kv_scale,
-                kv_scale=self.kv_scale,
+                q_scale=layer.k_scale,
+                kv_scale=layer.k_scale,
                 intra_batch_mode=intra_batch_mode,
                 num_kv_splits=num_kv_splits,
                 # num_kv_splits_indptr=num_kv_splits_indptr,
             )
@@ -1267,9 +1267,12 @@ def forward_decode(
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
 
         if layer.qk_head_dim != layer.v_head_dim:
-            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            o = q.new_empty(
+                (q.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                dtype=self.input_dtype,
+            )
         else:
-            o = torch.empty_like(q)
+            o = torch.empty_like(q, dtype=torch.bfloat16)
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -1277,7 +1280,9 @@ def forward_decode(
             )
 
         if self.use_mla:
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_buffer = (
+                k  # forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            )
 
             work_metadata = self.forward_metadata.work_metadata
             work_indptr = self.forward_metadata.work_indptr
@@ -1290,15 +1295,6 @@ def forward_decode(
             num_kv_splits = self.forward_metadata.num_kv_splits
             # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr
 
-            if self.kv_cache_dtype == fp8_dtype:
-                # q_input, q_scale = scaled_fp8_quant(
-                #     q,
-                # )
-                # q_scale = q_scale.to(torch.float)
-                q_input = q.to(fp8_dtype)
-            else:
-                q_input = q
-
             if layer.layer_id == 0 and _use_mla_ps_kernel:
                 self.make_mla_meta_data(
                     self.forward_metadata.qo_indptr,
@@ -1316,7 +1312,7 @@ def forward_decode(
                 )
 
             mla_decode_fwd(
-                q_input.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                 k_buffer.view(-1, 1, 1, layer.qk_head_dim),
                 o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                 self.forward_metadata.qo_indptr,
@@ -1332,8 +1328,8 @@ def forward_decode(
                 reduce_indptr=reduce_indptr,
                 reduce_final_map=reduce_final_map,
                 reduce_partial_map=reduce_partial_map,
-                q_scale=self.kv_scale,
-                kv_scale=self.kv_scale,
+                q_scale=layer.k_scale,
+                kv_scale=layer.k_scale,
                 intra_batch_mode=intra_batch_mode,
                 num_kv_splits=num_kv_splits,
                 # num_kv_splits_indptr=num_kv_splits_indptr,
diff --git a/python/sglang/srt/layers/quantization/quark/quark.py b/python/sglang/srt/layers/quantization/quark/quark.py
index 37500e6877e0..783f8ea4b118 100644
--- a/python/sglang/srt/layers/quantization/quark/quark.py
+++ b/python/sglang/srt/layers/quantization/quark/quark.py
@@ -71,6 +71,8 @@ def get_quant_method(
         ):
             if isinstance(layer, LinearBase):
                 return UnquantizedLinearMethod()
+            elif isinstance(layer, RadixAttention):
+                return QuarkKVCacheMethod(self)
             return None
 
         if isinstance(layer, LinearBase):
diff --git a/python/sglang/srt/layers/rocm_linear_utils.py b/python/sglang/srt/layers/rocm_linear_utils.py
index ee7dd1f59ed5..6c8a6a367e54 100644
--- a/python/sglang/srt/layers/rocm_linear_utils.py
+++ b/python/sglang/srt/layers/rocm_linear_utils.py
@@ -1,11 +1,12 @@
 import torch
 
+from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
 from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
 from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
 
 from sglang.srt.utils import BumpAllocator
 
-__all__ = ["fused_qk_rope_cat"]
+__all__ = ["fused_qk_rope_cat", "fused_qk_rope_cat_and_cache_mla"]
 
 
 def aiter_dsv3_router_gemm(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 584b15bf40a5..4413dbb4dab5 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -103,6 +103,7 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
@@ -190,7 +191,7 @@
 )
 from sglang.srt.layers.rocm_linear_utils import (
     aiter_dsv3_router_gemm,
-    fused_qk_rope_cat,
+    fused_qk_rope_cat_and_cache_mla,
     get_dsv3_gemm_output_zero_allocator_size,
 )
@@ -1964,6 +1965,8 @@ def forward_absorb_core(
         positions,
         topk_indices,
     ):
+        save_kv_cache = True
+
         if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
@@ -1986,16 +1989,29 @@ def forward_absorb_core(
             if _use_aiter_gfx95:
                 cos = self.rotary_emb.cos_cache
                 sin = self.rotary_emb.sin_cache
-                q, k = fused_qk_rope_cat(
+
+                kv_cache_dtype = (
+                    fp8_dtype if self.kv_cache_dtype == "fp8_e4m3" else q_nope_out.dtype
+                )
+
+                q, _, _, k = fused_qk_rope_cat_and_cache_mla(
                     q_nope_out,
                     q_pe,
                     k_nope,
                     k_pe,
+                    forward_batch.token_to_kv_pool.get_key_buffer(
+                        self.attn_mqa.layer_id
+                    ),
+                    forward_batch.out_cache_loc,
                     positions,
                     cos,
                     sin,
+                    self.attn_mqa.k_scale,
                     self.rotary_emb.is_neox_style,
+                    q_out_dtype=kv_cache_dtype,
                 )
+
+                save_kv_cache = False
             else:
                 q = torch.cat([q_nope_out, q_pe], dim=-1)
                 k = torch.cat([k_nope, k_pe], dim=-1)
@@ -2005,6 +2021,7 @@ def forward_absorb_core(
             k,
             k_nope,
             forward_batch,
+            save_kv_cache=save_kv_cache,
             **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
         )
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

From 4cd758b4a1ceff5b48ec2a5c2fcfabae0de84c5a Mon Sep 17 00:00:00 2001
From: wunhuang
Date: Tue, 2 Dec 2025 02:06:09 +0000
Subject: [PATCH 16/21] Fix the runtime error when running the fp8 deepseek-v3
 model

---
 python/sglang/srt/layers/quantization/fp8.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 5db8967c4ac6..ab77d1ab8241 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -171,6 +171,7 @@ def get_quant_method(
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.radix_attention import RadixAttention
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -178,6 +179,8 @@ def get_quant_method(
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
+        elif isinstance(layer, RadixAttention):
+            return Fp8KVCacheMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:

From 30c7f41b2341b2b7d695bb7fc1bb8c725a22d80a Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 3 Dec 2025 23:58:58 +0000
Subject: [PATCH 17/21] Remove unnecessary code

---
 .../srt/layers/attention/aiter_backend.py | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 3909f37eb186..d93be26da931 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -658,8 +658,6 @@ def init_forward_metadata_capture_cuda_graph(
         reduce_final_map = None
         reduce_partial_map = None
 
-        # log_info_on_rank0(logger, f"[init_forward_metadata_capture_cuda_graph] {forward_mode=}")
-
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
@@ -903,8 +901,6 @@ def init_forward_metadata_replay_cuda_graph(
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
-        # log_info_on_rank0(logger, f"[init_forward_metadata_replay_cuda_graph] {forward_mode=}")
-
         if forward_mode.is_decode_or_idle():
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
@@ -1121,7 +1117,6 @@ def forward_extend(
             reduce_partial_map = self.forward_metadata.reduce_partial_map
 
             num_kv_splits = self.forward_metadata.num_kv_splits
-            # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr
 
             if layer.layer_id == 0 and _use_mla_ps_kernel:
                 self.make_mla_meta_data(
@@ -1160,7 +1155,6 @@ def forward_extend(
                 kv_scale=layer.k_scale,
                 intra_batch_mode=intra_batch_mode,
                 num_kv_splits=num_kv_splits,
-                # num_kv_splits_indptr=num_kv_splits_indptr,
             )
             return o
         elif forward_batch.forward_mode.is_draft_extend():
@@ -1178,7 +1172,6 @@ def forward_extend(
             reduce_partial_map = self.forward_metadata.reduce_partial_map
 
             num_kv_splits = self.forward_metadata.num_kv_splits
-            # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr
 
             if layer.layer_id == 0 and _use_mla_ps_kernel:
                 self.make_mla_meta_data(
@@ -1197,7 +1190,7 @@ def forward_extend(
                 )
 
             mla_decode_fwd(
-                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q,
                 K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
                 o,
                 self.forward_metadata.qo_indptr,
@@ -1217,7 +1210,6 @@ def forward_extend(
                 kv_scale=layer.k_scale,
                 intra_batch_mode=intra_batch_mode,
                 num_kv_splits=num_kv_splits,
-                # num_kv_splits_indptr=num_kv_splits_indptr,
             )
             return o
         else:
@@ -1280,9 +1272,7 @@ def forward_decode(
             )
 
         if self.use_mla:
-            k_buffer = (
-                k  # forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            )
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
             work_metadata = self.forward_metadata.work_metadata
             work_indptr = self.forward_metadata.work_indptr
@@ -1293,7 +1283,6 @@ def forward_decode(
             reduce_partial_map = self.forward_metadata.reduce_partial_map
 
             num_kv_splits = self.forward_metadata.num_kv_splits
-            # num_kv_splits_indptr = self.forward_metadata.num_kv_splits_indptr
 
             if layer.layer_id == 0 and _use_mla_ps_kernel:
                 self.make_mla_meta_data(
@@ -1332,9 +1321,7 @@ def forward_decode(
                 kv_scale=layer.k_scale,
                 intra_batch_mode=intra_batch_mode,
                 num_kv_splits=num_kv_splits,
-                # num_kv_splits_indptr=num_kv_splits_indptr,
             )
-            # k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
         else:
             self.logits_soft_cap = layer.logit_cap

From a0f9685e47777a6dcc7a2437df2a4e0f86b01590 Mon Sep 17 00:00:00 2001
From: root
Date: Sat, 6 Dec 2025 08:09:11 +0000
Subject: [PATCH 18/21] Fix MTP + FP8-KV + DP accuracy issue

---
 .../srt/layers/attention/aiter_backend.py     | 29 +++++++++++++------
 .../sglang/srt/model_executor/model_runner.py |  2 +-
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index d93be26da931..6ea7883c4f9e 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -190,10 +190,19 @@ def __init__(
         if self.num_draft_tokens is None and _use_mla_ps_kernel:
             self.max_split_per_batch = 64
 
+        self.fix_max_split_per_batch = self.max_split_per_batch
+
     def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size):
         nhead = self.num_head
         dtype = self.kv_cache_dtype
 
+        gpu = torch.cuda.current_device()
+        device_properties = torch.cuda.get_device_properties(gpu)
+        cu_num = device_properties.multi_processor_count
+        self.max_split_per_batch = min(
+            (cu_num + batch_size - 1) // batch_size, self.fix_max_split_per_batch
+        )
+
         (
             (work_meta_data_size, work_meta_data_type),
             (work_indptr_size, work_indptr_type),
@@ -263,6 +272,8 @@ def make_mla_meta_data(
         nhead_kv = 1
         page_size = 1
 
+        dtype = self.kv_cache_dtype
+
         meta = get_mla_metadata_v1(
             qo_indptr,
             kv_indptr,
@@ -280,7 +291,9 @@ def make_mla_meta_data(
             uni_seqlen_qo=max_q_len,
             fast_mode=fast_mode,
             max_split_per_batch=max_split_per_batch,
-            intera_batch_mode=intra_batch_mode,
+            intra_batch_mode=intra_batch_mode,
+            dtype_q=dtype,
+            dtype_kv=dtype,
         )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -330,8 +343,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 max_q_len = 1
 
             if _use_mla_ps_kernel:
-                num_kv_splits = self.max_split_per_batch
-
                 (
                     work_metadata,
                     work_indptr,
@@ -341,6 +352,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     reduce_partial_map,
                 ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs)
 
+                num_kv_splits = self.max_split_per_batch
+
                 self.make_mla_meta_data(
                     qo_indptr,
                     kv_indptr,
@@ -384,9 +397,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             )
 
             if _use_mla_ps_kernel:
-
-                num_kv_splits = self.max_split_per_batch
-
                 max_seqlen_qo = max(forward_batch.extend_seq_lens_cpu)
                 (
                     work_metadata,
@@ -397,6 +407,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     reduce_partial_map,
                 ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs)
 
+                num_kv_splits = self.max_split_per_batch
+
                 self.make_mla_meta_data(
                     qo_indptr,
                     kv_indptr,
@@ -480,9 +492,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             # if self.kv_cache_dtype == fp8_dtype:
             if _use_mla_ps_kernel:
-
-                num_kv_splits = self.max_split_per_batch
-
                 max_seqlen_qo = draft_num
                 (
                     work_metadata,
@@ -493,6 +502,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     reduce_partial_map,
                 ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, bs)
 
+                num_kv_splits = self.max_split_per_batch
+
                 self.make_mla_meta_data(
                     qo_indptr,
                     kv_indptr,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fd557b78124f..513c8ff279a8 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -93,8 +93,8 @@
     set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput
+from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager

From f7492c6ed4cba05312e3993690b6d7786dc1678f Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 7 Dec 2025 03:01:10 +0000
Subject: [PATCH 19/21] Fix the performance regression when dp attention is
 not used

---
 python/sglang/srt/layers/attention/aiter_backend.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 6ea7883c4f9e..7d0ca774a115 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -196,12 +196,13 @@ def make_mla_decode_meta_data_buffer(self, max_seqlen_qo, batch_size):
         nhead = self.num_head
         dtype = self.kv_cache_dtype
 
-        gpu = torch.cuda.current_device()
-        device_properties = torch.cuda.get_device_properties(gpu)
-        cu_num = device_properties.multi_processor_count
-        self.max_split_per_batch = min(
-            (cu_num + batch_size - 1) // batch_size, self.fix_max_split_per_batch
-        )
+        if self.enable_dp_attention:
+            gpu = torch.cuda.current_device()
+            device_properties = torch.cuda.get_device_properties(gpu)
+            cu_num = device_properties.multi_processor_count
+            self.max_split_per_batch = min(
+                (cu_num + batch_size - 1) // batch_size, self.fix_max_split_per_batch
+            )
 
From 1610fc9b97c13fdd4f016cceebf0694104a59ab3 Mon Sep 17 00:00:00 2001
From: wunhuang
Date: Mon, 8 Dec 2025 03:05:34 +0000
Subject: [PATCH 20/21] Fix CI error "Command 'cd
 /root/.aiter/build/pa_ragged_afefebe1d44cff6c285a28cd6304239b && make build
 -j1' returned non-zero exit status 2."
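
The decode output buffer was allocated with a hard-coded torch.bfloat16, so a
model whose activation dtype is not bf16 hit the aiter pa_ragged JIT build
failure quoted above. Allocate the buffer in the model's activation dtype
instead. A minimal sketch of the intended allocation (a free-standing helper
for illustration only; in the backend this happens inline in forward_decode,
and input_dtype is taken from model_runner.model_config.dtype):

    import torch

    def alloc_decode_output(
        q: torch.Tensor,
        tp_q_head_num: int,
        qk_head_dim: int,
        v_head_dim: int,
        input_dtype: torch.dtype,
    ) -> torch.Tensor:
        # q may already be fp8 when the KV cache is fp8, so the output must
        # not inherit q's dtype; it is produced in the activation dtype.
        if qk_head_dim != v_head_dim:
            return q.new_empty(
                (q.shape[0], tp_q_head_num * v_head_dim), dtype=input_dtype
            )
        return torch.empty_like(q, dtype=input_dtype)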
---
 python/sglang/srt/layers/attention/aiter_backend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 7d0ca774a115..3a858ca8d391 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -1276,7 +1276,7 @@ def forward_decode(
                 dtype=self.input_dtype,
             )
         else:
-            o = torch.empty_like(q, dtype=torch.bfloat16)
+            o = torch.empty_like(q, dtype=self.input_dtype)
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(

From 9889990d33a24ab1093d0ed3dd61b18f63996fca Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 9 Dec 2025 00:33:42 +0000
Subject: [PATCH 21/21] Add more comments and set SGLANG_AITER_MLA_PERSIST
 default to True

---
 Makefile                                  | 49 -------------------
 .../srt/layers/attention/aiter_backend.py | 13 ++---
 2 files changed, 7 insertions(+), 55 deletions(-)
 delete mode 100644 Makefile

diff --git a/Makefile b/Makefile
deleted file mode 100644
index d6ef1942042e..000000000000
--- a/Makefile
+++ /dev/null
@@ -1,49 +0,0 @@
-.PHONY: check-deps install-deps format update help
-
-# Show help for each target
-help:
-	@echo "Available targets:"
-	@grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
-
-check-deps: ## Check and install required Python formatting dependencies
-	@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
-	@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
-
-install-deps: ## Install Python formatting tools (isort and black)
-	pip install isort black
-
-format: check-deps ## Format modified Python files using isort and black
-	@echo "Formatting modified Python files..."
-	git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
-
-FILES_TO_UPDATE = docker/rocm.Dockerfile \
-                  python/pyproject.toml \
-                  python/pyproject_other.toml \
-                  python/sglang/version.py \
-                  docs/developer_guide/setup_github_runner.md \
-                  docs/get_started/install.md \
-                  docs/platforms/amd_gpu.md \
-                  docs/platforms/ascend_npu.md \
-                  docs/platforms/cpu_server.md \
-                  docs/platforms/xpu.md \
-                  benchmark/deepseek_v3/README.md
-
-update: ## Update version numbers across project files. Usage: make update <new_version>
-	@if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \
-		echo "Version required. Usage: make update <new_version>"; \
-		exit 1; \
-	fi
-	@OLD_VERSION=$$(grep "version" python/sglang/version.py | cut -d '"' -f2); \
-	NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
-	echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
-	for file in $(FILES_TO_UPDATE); do \
-		if [ "$(shell uname)" = "Darwin" ]; then \
-			sed -i '' -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
-		else \
-			sed -i -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
-		fi \
-	done; \
-	echo "Version update complete"
-
-%:
-	@:
diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 3a858ca8d391..d28f8e4f988e 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -29,6 +29,7 @@
 from aiter import (
     flash_attn_varlen_func,
     get_mla_metadata_info_v1,
+    get_mla_metadata_v1,
     mha_batch_prefill_func,
     paged_attention_ragged,
 )
@@ -38,23 +39,21 @@
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
 
-
-from aiter import get_mla_metadata_v1
-
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
-_use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST")
+# Use the aiter MLA persistent-kernel design for the fp8 KV cache
+_use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST", "True")
 
 # Persist
 # fast_mode=True if _use_mla_ps_kernel else False
 # intra_batch_mode=False if _use_mla_ps_kernel else True
 
-# fake non-ps
-fast_mode = False if _use_mla_ps_kernel else False
+# Fake non-ps: intra_batch_mode needs to be True in non-persistent mode
+fast_mode = False
 intra_batch_mode = True if _use_mla_ps_kernel else False
@@ -1235,6 +1234,7 @@ def forward_extend(
 
         bs0 = forward_batch.batch_size + 1
 
+        # TODO (kkhuang-amd): remove this once mha_batch_prefill_func supports the fp8 KV cache
         if self.kv_cache_dtype == fp8_dtype:
             dtype = q.dtype
             k_cache = k_cache.to(dtype)
@@ -1341,6 +1341,7 @@ def forward_decode(
                 layer.layer_id
             )
 
+            # TODO (kkhuang-amd): remove this once paged_attention_ragged supports the fp8 KV cache
             if self.kv_cache_dtype == fp8_dtype:
                 dtype = q.dtype
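
Note on the MTP + FP8-KV + DP fix in patches 18 and 19: with dp attention
enabled, each rank may see only a small local batch, so the number of KV
splits per batch entry is derived from the device's compute-unit count
rather than always using the fixed bound; without dp attention the fixed
bound chosen at init (64 for plain decode) is kept, which is the regression
patch 19 addresses. A minimal sketch of the heuristic (packaged as a free
function for illustration; in the backend it runs inside
make_mla_decode_meta_data_buffer):

    import torch

    def pick_max_split_per_batch(
        batch_size: int,
        fix_max_split_per_batch: int,
        enable_dp_attention: bool,
    ) -> int:
        if not enable_dp_attention:
            # Keep the fixed upper bound chosen at init time.
            return fix_max_split_per_batch
        # Spread the GPU's compute units (CUs) across the batch:
        # ceil(cu_num / batch_size) splits per entry, capped by the bound.
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        cu_num = props.multi_processor_count
        return min((cu_num + batch_size - 1) // batch_size, fix_max_split_per_batch)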