14 changes: 10 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1632,7 +1632,9 @@ def merge_batch(self, other: "ScheduleBatch"):
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)

def get_model_worker_batch(self) -> ModelWorkerBatch:
def get_model_worker_batch(
self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
) -> ModelWorkerBatch:
if self.forward_mode.is_decode_or_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
@@ -1642,16 +1644,20 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:

# Create seq_lens_cpu when needed
if (
(
global_server_args_dict["attention_backend"] == "fa3"
or (
global_server_args_dict["use_mla_backend"]
and global_server_args_dict["attention_backend"] == "flashinfer"
)
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "fa3"
or global_server_args_dict["attention_backend"] == "cutlass_mla"
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = self.seq_lens.cpu()
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
)
else:
seq_lens_cpu = None

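The new seq_lens_cpu_cache argument lets a caller that already holds a host copy of seq_lens reuse it instead of paying another device-to-host copy per batch. A minimal caller-side sketch (variable names here are illustrative, not from this PR):

    # Illustrative: perform the D2H copy once per step and pass it through.
    seq_lens_cpu = batch.seq_lens.cpu()  # single device-to-host transfer
    model_worker_batch = batch.get_model_worker_batch(seq_lens_cpu_cache=seq_lens_cpu)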
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/scheduler.py
@@ -1574,10 +1574,9 @@ def run_batch(
num_accepted_tokens,
can_run_cuda_graph,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
num_accepted_tokens + batch.batch_size()
)
self.spec_num_total_forward_ct += batch.batch_size()
bs = batch.batch_size()
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
self.spec_num_total_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens

if self.pp_group.is_last_rank:
23 changes: 13 additions & 10 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -56,6 +56,16 @@ def get_is_capture_mode():
return is_capture_mode


@contextmanager
def model_capture_mode():
global is_capture_mode
is_capture_mode = True

yield

is_capture_mode = False
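Hoisting model_capture_mode from a CudaGraphRunner method to module scope lets the EAGLE graph runners below reuse the same flag. A hedged sketch of how the flag is consumed through the existing get_is_capture_mode() accessor (the helper below is illustrative, not part of this diff):

    # Illustrative: avoid host-synchronizing work while a CUDA graph is being recorded.
    def debug_log_seq_lens(seq_lens: torch.Tensor):
        if get_is_capture_mode():
            return  # .item() / print would force a device sync during capture
        print("max seq len:", seq_lens.max().item())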


def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
@@ -291,22 +301,13 @@ def __init__(self, model_runner: ModelRunner):

# Capture
try:
with self.model_capture_mode():
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)

@contextmanager
def model_capture_mode(self):
global is_capture_mode
is_capture_mode = True

yield

is_capture_mode = False

def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
@@ -650,6 +651,8 @@ def get_spec_info(self, num_tokens: int):
topk=self.model_runner.server_args.speculative_eagle_topk,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
capture_hidden_mode=CaptureHiddenMode.FULL,
seq_lens_sum=None,
seq_lens_cpu=None,
)

return spec_info
8 changes: 4 additions & 4 deletions python/sglang/srt/server_args.py
@@ -1013,13 +1013,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=str,
choices=[
"aiter",
"flashinfer",
"triton",
"torch_native",
"cutlass_mla",
"fa3",
"flashinfer",
"flashmla",
"cutlass_mla",
"intel_amx",
"torch_native",
"triton",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
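For reference, the attention backend is selected at server launch; a usage sketch (the model path is a placeholder):

    python -m sglang.launch_server --model-path <your-model> --attention-backend fa3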
@@ -10,6 +10,7 @@
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
set_global_graph_memory_pool,
set_torch_compile_config,
)
@@ -80,7 +81,8 @@ def __init__(self, eagle_worker: EAGLEWorker):

# Capture
try:
self.capture()
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
@@ -11,6 +11,7 @@
LogitsProcessorOutput,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
set_global_graph_memory_pool,
set_torch_compile_config,
)
@@ -19,7 +20,7 @@
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk

if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
Expand All @@ -37,6 +38,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.topk = model_runner.server_args.speculative_eagle_topk
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.padded_static_len = -1

@@ -87,7 +89,8 @@ def __init__(self, eagle_worker: EAGLEWorker):

# Capture
try:
self.capture()
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
@@ -170,6 +173,8 @@ def run_once():
forward_batch.positions,
forward_batch,
)
probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)

forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
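The capture path now also records the draft-side top-k over the next-token distribution, so topk_p and topk_index come out of the graph together with the logits. A hedged equivalent of that step, assuming fast_topk mirrors torch.topk and returns (values, indices):

    # Illustrative equivalent of the captured top-k computation.
    probs = torch.softmax(next_token_logits, dim=-1)        # [bs, vocab_size]
    topk_p, topk_index = torch.topk(probs, k=topk, dim=-1)  # [bs, topk] each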
@@ -198,7 +203,7 @@ def replay(self, forward_batch: ForwardBatch):

index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
if bs * self.num_tokens_per_bs != num_tokens:
self.seq_lens.fill_(1)
self.accept_length.fill_(1)
self.out_cache_loc.zero_()
@@ -238,8 +243,11 @@ def replay(self, forward_batch: ForwardBatch):
out = self.output_buffers[bs]
if bs != raw_bs:
forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
out_copy = out
out = LogitsProcessorOutput(
next_token_logits=out.next_token_logits[:raw_bs],
hidden_states=out.hidden_states[:raw_bs],
)
out.topk_p = out_copy.topk_p[:raw_bs]
out.topk_index = out_copy.topk_index[:raw_bs]
return out
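Replay rounds the runtime batch up to the nearest captured batch size (capture_bs is kept sorted, so bisect_left finds it) and then slices the padded outputs, including the new topk_p / topk_index fields, back down to raw_bs. A small standalone sketch of the rounding, with illustrative sizes:

    # Illustrative: round a runtime batch size up to the nearest captured size.
    import bisect
    capture_bs = [1, 2, 4, 8, 16]                             # example capture sizes
    raw_bs = 5
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]   # -> 8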