Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions python/sglang/srt/layers/attention/trtllm_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
seq_lens = seq_lens + self.num_draft_tokens
self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
max_seq = forward_batch.seq_lens_cpu.max().item()

sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
cu_seqlens_q = torch.nn.functional.pad(
Expand Down Expand Up @@ -985,25 +983,6 @@ def forward_extend(
)
else:
max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
# Check if we're in CUDA graph mode (buffers are pre-allocated)
if self.padded_q_buffer is not None:
# Use pre-allocated buffer for CUDA graph compatibility
padded_q = self.padded_q_buffer[
:bs, : metadata.max_seq_len_q, :, :
].to(dtype=q.dtype)
else:
# Dynamic allocation for non-CUDA graph mode
padded_q = torch.zeros(
bs,
metadata.max_seq_len_q,
layer.tp_q_head_num,
layer.head_dim,
dtype=q.dtype,
device=q.device,
)
q = self.pad_draft_extend_query(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After removing the pad_draft_extend_query call, maybe we should already delete the relate code.

q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
)

# TODO may use `mla_rope_quantize_fp8` fusion
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
Expand All @@ -1022,15 +1001,6 @@ def forward_extend(
bmm1_scale=bmm1_scale,
)

# Reshape output directly without slicing

if forward_batch.forward_mode.is_draft_extend(include_v2=True):
raw_out = self.unpad_draft_extend_output(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

raw_out,
metadata.cu_seqlens_q,
metadata.seq_lens_q,
metadata.sum_seq_lens_q,
)
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output

Expand Down
9 changes: 5 additions & 4 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
set_is_extend_in_batch,
)
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
from sglang.srt.utils.common import ceil_align

if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
Expand Down Expand Up @@ -731,9 +732,7 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
for i in range(sync_group_size):
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
global_num_tokens[i] = (
(global_num_tokens[i] - 1) // attn_tp_size + 1
) * attn_tp_size
global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_tp_size)

dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
self.is_extend_in_batch, global_num_tokens
Expand Down Expand Up @@ -764,7 +763,9 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):

bs = self.batch_size

if self.forward_mode.is_decode():
if self.forward_mode.is_decode() or self.forward_mode.is_draft_extend(
include_v2=True
):
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
setattr(self, "_original_forward_mode", self.forward_mode)
self.forward_mode = ForwardMode.EXTEND
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def _draft_extend_for_decode(
)
else:
draft_logits_output, _ = self.draft_runner.forward(
forward_batch, skip_attn_backend_init=True
forward_batch, skip_attn_backend_init=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skip attn backend lead to the flashinfer backend cudagraph capture failure.

)

# Reorganize the spec info for the next batch
Expand Down
78 changes: 74 additions & 4 deletions test/manual/test_eagle_infer_beta_dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
Expand All @@ -15,14 +18,12 @@
write_github_step_summary,
)

FULL_DEEPSEEK_V3_FP4_MODEL_PATH = "nvidia/DeepSeek-V3-0324-FP4"


class TestEagleDPAttnServerBase(CustomTestCase):
class TestEagleDPAttnServerLarge(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_ENABLE_SPEC_V2"] = "1"
cls.model = FULL_DEEPSEEK_V3_FP4_MODEL_PATH
cls.model = DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--tp-size",
Expand Down Expand Up @@ -93,5 +94,74 @@ def test_a_gsm8k(
self.assertGreater(avg_spec_accept_length, 2.04)


class TestEagleDPAttnServerSmall(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_ENABLE_SPEC_V2"] = "1"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--tp-size",
"2",
"--dp-size",
"2",
"--enable-dp-attention",
"--speculative-draft-model-path",
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"4",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
if "SGLANG_ENABLE_SPEC_V2" in os.environ:
del os.environ["SGLANG_ENABLE_SPEC_V2"]

def test_a_gsm8k(
self,
): # Append an "a" to make this test run first (alphabetically) to warm up the server
requests.get(self.base_url + "/flush_cache")

args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"{metrics=}")

server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["internal_states"][0][
"avg_spec_accept_length"
]
print(f"{avg_spec_accept_length=}")

if is_in_ci():
write_github_step_summary(
f"### test_gsm8k (deepseek-v3-fp4 mtp)\n"
f'{metrics["accuracy"]=:.3f}\n'
f"{avg_spec_accept_length=:.2f}\n"
)
self.assertGreater(metrics["accuracy"], 0.94)
self.assertGreater(avg_spec_accept_length, 2.04)


if __name__ == "__main__":
unittest.main()
Loading