From 5f829c4b83bebb1eb63625666e21ffc749aba566 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 4 Dec 2025 13:17:38 +0800 Subject: [PATCH 1/5] add small model --- .../test_eagle_infer_beta_dp_attention.py | 78 ++++++++++++++++++- 1 file changed, 74 insertions(+), 4 deletions(-) diff --git a/test/manual/test_eagle_infer_beta_dp_attention.py b/test/manual/test_eagle_infer_beta_dp_attention.py index 382196a18fd5..97e35f7e11e9 100644 --- a/test/manual/test_eagle_infer_beta_dp_attention.py +++ b/test/manual/test_eagle_infer_beta_dp_attention.py @@ -7,6 +7,9 @@ from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( + DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -15,14 +18,12 @@ write_github_step_summary, ) -FULL_DEEPSEEK_V3_FP4_MODEL_PATH = "nvidia/DeepSeek-V3-0324-FP4" - -class TestEagleDPAttnServerBase(CustomTestCase): +class TestEagleDPAttnServerLarge(CustomTestCase): @classmethod def setUpClass(cls): os.environ["SGLANG_ENABLE_SPEC_V2"] = "1" - cls.model = FULL_DEEPSEEK_V3_FP4_MODEL_PATH + cls.model = DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST other_args = [ "--tp-size", @@ -93,5 +94,74 @@ def test_a_gsm8k( self.assertGreater(avg_spec_accept_length, 2.04) +class TestEagleDPAttnServerSmall(CustomTestCase): + @classmethod + def setUpClass(cls): + os.environ["SGLANG_ENABLE_SPEC_V2"] = "1" + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--tp-size", + "2", + "--dp-size", + "2", + "--enable-dp-attention", + "--speculative-draft-model-path", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + 
] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + if "SGLANG_ENABLE_SPEC_V2" in os.environ: + del os.environ["SGLANG_ENABLE_SPEC_V2"] + + def test_a_gsm8k( + self, + ): # Append an "a" to make this test run first (alphabetically) to warm up the server + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + + if is_in_ci(): + write_github_step_summary( + f"### test_gsm8k (deepseek mla small mtp)\n" + f'{metrics["accuracy"]=:.3f}\n' + f"{avg_spec_accept_length=:.2f}\n" + ) + self.assertGreater(metrics["accuracy"], 0.94) + self.assertGreater(avg_spec_accept_length, 2.04) + + if __name__ == "__main__": unittest.main() From 2590bfc5be7e056e8bd5c1ad30d6c678153a2b12 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 4 Dec 2025 13:18:20 +0800 Subject: [PATCH 2/5] remove padding --- .../layers/attention/trtllm_mla_backend.py | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 0c8f832ed26b..2824ff8a6761 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -594,8 +594,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens = seq_lens + self.num_draft_tokens self.forward_decode_metadata.seq_lens_k = 
seq_lens.to(torch.int32) elif forward_batch.forward_mode.is_draft_extend(include_v2=True): - max_seq = forward_batch.seq_lens_cpu.max().item() - sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu) max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) cu_seqlens_q = torch.nn.functional.pad( @@ -985,25 +983,6 @@ def forward_extend( ) else: max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q - # Check if we're in CUDA graph mode (buffers are pre-allocated) - if self.padded_q_buffer is not None: - # Use pre-allocated buffer for CUDA graph compatibility - padded_q = self.padded_q_buffer[ - :bs, : metadata.max_seq_len_q, :, : - ].to(dtype=q.dtype) - else: - # Dynamic allocation for non-CUDA graph mode - padded_q = torch.zeros( - bs, - metadata.max_seq_len_q, - layer.tp_q_head_num, - layer.head_dim, - dtype=q.dtype, - device=q.device, - ) - q = self.pad_draft_extend_query( - q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q - ) # TODO may use `mla_rope_quantize_fp8` fusion q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) @@ -1022,15 +1001,6 @@ def forward_extend( bmm1_scale=bmm1_scale, ) - # Reshape output directly without slicing - - if forward_batch.forward_mode.is_draft_extend(include_v2=True): - raw_out = self.unpad_draft_extend_output( - raw_out, - metadata.cu_seqlens_q, - metadata.seq_lens_q, - metadata.sum_seq_lens_q, - ) output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) return output From 99126f9cab112ef59a84a38c245148d287fac75e Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 4 Dec 2025 13:18:45 +0800 Subject: [PATCH 3/5] tiny adjust --- python/sglang/srt/model_executor/forward_batch_info.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8baddc56f0c5..cd08092af33c 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ 
b/python/sglang/srt/model_executor/forward_batch_info.py @@ -52,6 +52,7 @@ set_is_extend_in_batch, ) from sglang.srt.utils import get_compiler_backend, is_npu, support_triton +from sglang.srt.utils.common import ceil_align if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -731,9 +732,7 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): for i in range(sync_group_size): # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob - global_num_tokens[i] = ( - (global_num_tokens[i] - 1) // attn_tp_size + 1 - ) * attn_tp_size + global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_tp_size) dp_padding_mode = DpPaddingMode.get_dp_padding_mode( self.is_extend_in_batch, global_num_tokens From 8bf2cb3b8dd8b2df70ddae40942cfe900fc995cf Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 4 Dec 2025 13:19:38 +0800 Subject: [PATCH 4/5] skip init attn backend = false --- python/sglang/srt/speculative/eagle_worker_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 663c4f3415fb..e344fb99f527 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -512,7 +512,7 @@ def _draft_extend_for_decode( ) else: draft_logits_output, _ = self.draft_runner.forward( - forward_batch, skip_attn_backend_init=True + forward_batch, skip_attn_backend_init=False ) # Reorganize the spec info for the next batch From 84c6b806200e000b50c87c4b3ea1b34b8a1f4a87 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 4 Dec 2025 13:26:10 +0800 Subject: [PATCH 5/5] fix --- python/sglang/srt/model_executor/forward_batch_info.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index cd08092af33c..748d4a145ebf 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -763,7 +763,9 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): bs = self.batch_size - if self.forward_mode.is_decode(): + if self.forward_mode.is_decode() or self.forward_mode.is_draft_extend( + include_v2=True + ): if self.is_extend_in_batch and dp_padding_mode.is_max_len(): setattr(self, "_original_forward_mode", self.forward_mode) self.forward_mode = ForwardMode.EXTEND