diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 1c254c4fa50..b4cf99dff8f 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -358,6 +358,35 @@ def init_forward_metadata_capture_cuda_graph(
             )
             self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
             self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        elif forward_mode.is_draft_extend():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                    )
+                )
+
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
@@ -392,6 +421,17 @@ def init_forward_metadata_replay_cuda_graph(
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
             )
+        elif forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
         else:
             raise ValueError("Invalid forward mode")
 
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index a6a255b3bb6..57ad6fc300f 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -278,6 +278,28 @@ def init_forward_metadata_capture_cuda_graph(
             )
             self.prefill_cuda_graph_metadata[bs] = verify_wrapper
             self.forward_metadata = PrefillMetadata(verify_wrapper, False)
+        elif forward_mode.is_draft_extend():
+            draft_extend_wrapper = BatchMLAPagedAttentionWrapper(
+                self.workspace_buffer,
+                use_cuda_graph=True,
+                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
+                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
+                kv_indices=self.cuda_graph_kv_indices,
+                kv_len_arr=self.cuda_graph_kv_lens[:bs],
+                backend="auto",
+            )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=draft_extend_wrapper,
+                use_ragged=False,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = draft_extend_wrapper
+            self.forward_metadata = PrefillMetadata(draft_extend_wrapper, False)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
@@ -325,6 +347,16 @@
                 use_ragged=False,
                 spec_info=spec_info,
             )
+        elif forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                spec_info=spec_info,
+            )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index 3fd42737599..e817196e6c3 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -80,7 +80,9 @@ def __init__(self, eagle_worker: EAGLEWorker):
 
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
         self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
-        self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)
+        self.accept_length = (
+            torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
+        )
 
         # Capture
         try:
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 1c78714b70f..a9193150b72 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -156,6 +156,7 @@ def init_attention_backend(self):
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferAttnBackend,
                     FlashInferMultiStepDraftBackend,
                 )
 
@@ -164,8 +165,13 @@
                     self.topk,
                     self.speculative_num_steps,
                 )
+                self.draft_extend_attn_backend = FlashInferAttnBackend(
+                    self.draft_model_runner,
+                    skip_prefill=False,
+                )
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAAttnBackend,
                     FlashInferMLAMultiStepDraftBackend,
                 )
 
@@ -174,7 +180,10 @@
                     self.topk,
                     self.speculative_num_steps,
                 )
-            self.draft_extend_attn_backend = None
+                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
+                    self.draft_model_runner,
+                    skip_prefill=False,
+                )
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 23c7ebdfe1b..7662ca33336 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -19,6 +19,7 @@ from sglang.test.test_utils import (
     DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
     DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
@@ -602,6 +603,7 @@ def setUpClass(cls):
                 "fa3",
             ],
         )
+        cls.accept_len_threshold = 1.50
 
     @classmethod
     def tearDownClass(cls):
@@ -636,7 +638,89 @@ def test_one_batch_accept_length(self):
             acc_length = 1.0
 
         print(f"{acc_length=}")
-        self.assertGreater(acc_length, 1.50)
+        self.assertGreater(acc_length, self.accept_len_threshold)
+
+
+class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                1,
+                "--speculative-eagle-topk",
+                1,
+                "--speculative-num-draft-tokens",
+                2,
+                "--max-running-requests",
+                4,
+                "--attention-backend",
+                "flashinfer",
+            ],
+        )
+        cls.accept_len_threshold = 1.50
+
+
+class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                1,
+                "--speculative-eagle-topk",
+                1,
+                "--speculative-num-draft-tokens",
+                2,
+                "--max-running-requests",
+                4,
+                "--attention-backend",
+                "triton",
+            ],
+        )
+        cls.accept_len_threshold = 1.50
+
+
+class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-num-steps",
+                1,
+                "--speculative-eagle-topk",
+                1,
+                "--speculative-num-draft-tokens",
+                2,
+                "--max-running-requests",
+                4,
+                "--attention-backend",
+                "flashinfer",
+            ],
+        )
+        cls.accept_len_threshold = 1.85
 
 
 if __name__ == "__main__":