Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,35 @@ def init_forward_metadata_capture_cuda_graph(
)
self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
elif forward_mode.is_draft_extend():
prefill_wrappers = []
for i in range(self.num_wrappers):
prefill_wrappers.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="fa2",
use_cuda_graph=True,
qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
)
)

seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
prefill_wrappers=prefill_wrappers,
use_ragged=False,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")

Expand Down Expand Up @@ -392,6 +421,17 @@ def init_forward_metadata_replay_cuda_graph(
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
)
elif forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
use_ragged=False,
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
)
else:
raise ValueError("Invalid forward mode")

Expand Down
32 changes: 32 additions & 0 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,28 @@ def init_forward_metadata_capture_cuda_graph(
)
self.prefill_cuda_graph_metadata[bs] = verify_wrapper
self.forward_metadata = PrefillMetadata(verify_wrapper, False)
elif forward_mode.is_draft_extend():
draft_extend_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.cuda_graph_kv_lens[:bs],
backend="auto",
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
prefill_wrapper_paged=draft_extend_wrapper,
use_ragged=False,
spec_info=spec_info,
)
self.prefill_cuda_graph_metadata[bs] = draft_extend_wrapper
self.forward_metadata = PrefillMetadata(draft_extend_wrapper, False)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")

Expand Down Expand Up @@ -325,6 +347,16 @@ def init_forward_metadata_replay_cuda_graph(
use_ragged=False,
spec_info=spec_info,
)
elif forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
use_ragged=False,
spec_info=spec_info,
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __init__(self, eagle_worker: EAGLEWorker):

self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)
self.accept_length = (
torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
)

# Capture
try:
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def init_attention_backend(self):
if self.server_args.attention_backend == "flashinfer":
if not global_server_args_dict["use_mla_backend"]:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
FlashInferMultiStepDraftBackend,
)

Expand All @@ -164,8 +165,13 @@ def init_attention_backend(self):
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = FlashInferAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
FlashInferMLAMultiStepDraftBackend,
)

Expand All @@ -174,7 +180,10 @@ def init_attention_backend(self):
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton":
Expand Down
86 changes: 85 additions & 1 deletion test/srt/test_eagle_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
Expand Down Expand Up @@ -602,6 +603,7 @@ def setUpClass(cls):
"fa3",
],
)
cls.accept_len_threshold = 1.50

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -636,7 +638,89 @@ def test_one_batch_accept_length(self):
acc_length = 1.0

print(f"{acc_length=}")
self.assertGreater(acc_length, 1.50)
self.assertGreater(acc_length, self.accept_len_threshold)


class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
1,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
2,
"--max-running-requests",
4,
"--attention-backend",
"flashinfer",
],
)
cls.accept_len_threshold = 1.50


class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
1,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
2,
"--max-running-requests",
4,
"--attention-backend",
"triton",
],
)
cls.accept_len_threshold = 1.50


class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
1,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
2,
"--max-running-requests",
4,
"--attention-backend",
"flashinfer",
],
)
cls.accept_len_threshold = 1.85


if __name__ == "__main__":
Expand Down
Loading