66 changes: 66 additions & 0 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1268,6 +1268,29 @@ def init_cuda_graph_state(self, max_bs: int):
),
}

self.draft_extend_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.zeros(
max_bs + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}

if self.topk > 1:
self.target_verify_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
@@ -1508,6 +1531,32 @@ def init_forward_metadata_capture_cuda_graph(

self.target_verify_metadata_topk_normal[bs] = metadata
self.target_verify_metadata_topk_expand[bs] = metadata_expand
elif forward_mode.is_draft_extend():
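# Draft-extend capture: each request contributes num_tokens // bs query tokens,
# so cu_seqlens_q below is a fixed-stride arange.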
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))

num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs
metadata.max_seq_len_k = seq_lens.max().item()

metadata.cu_seqlens_q = torch.arange(
0,
bs * num_tokens_per_bs + 1,
num_tokens_per_bs,
dtype=torch.int32,
device=device,
)

metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
: (bs + 1)
]
metadata.page_table = self.draft_extend_metadata["page_table"][
req_pool_indices, :
]

self.draft_extend_metadata[bs] = metadata

if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
@@ -1732,6 +1781,23 @@ def init_forward_metadata_replay_cuda_graph(
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))

metadata.max_seq_len_k = seq_lens_cpu.max().item()
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
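# Gather each request's token slots at page-stride offsets, then convert them
# to page indices for the paged KV-cache page table.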
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)

if encoder_lens is not None:
# Only support encoder size 1 for now
@@ -0,0 +1,250 @@
from __future__ import annotations

import bisect
from typing import TYPE_CHECKING, Callable

import torch

from sglang.srt.model_executor.cuda_graph_runner import (
CudaGraphRunner,
LogitsProcessorOutput,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker


class EAGLEDraftExtendCudaGraphRunner:
Review comment (Member):
BTW when enabling extended CUDA Graph, we should also adjust mem_fraction_static to avoid running out of memory. For example, set it to 5/4/8.

Review comment (Member):
python3 -m sglang.launch_server --model /dev/shm/DeepSeek-V3-0324 --tp 8 --trust-remote-code --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --speculative-algorithm EAGLE

lm_eval --model local-chat-completions --model_args model=/dev/shm/DeepSeek-V3-0324,base_url=http://127.0.0.1:30000/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks gsm8k --batch_size 128 --apply_chat_template --num_fewshot 8

This will be OOM.
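One way to apply the suggestion above is the --mem-fraction-static flag that the capture error message in this file also points to; the value below is illustrative and may need tuning for this 5/4/8 configuration:

python3 -m sglang.launch_server --model /dev/shm/DeepSeek-V3-0324 --tp 8 --trust-remote-code --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --speculative-algorithm EAGLE --mem-fraction-static 0.7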

def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.padded_static_len = -1

# Attention backend
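# After padding (prepare_extend_after_decode with pad_input=True in eagle_utils.py),
# every draft-extend request occupies speculative_num_steps + 1 token slots.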
self.num_tokens_per_bs = self.speculative_num_steps + 1
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs

self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
self.max_num_token
)
self.seq_len_fill_value = (
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
)
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)

if self.enable_torch_compile:
set_torch_compile_config()

# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)

if self.eagle_worker.speculative_algorithm.is_eagle3():
self.hidden_states = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size * 3,
),
dtype=self.model_runner.dtype,
)
else:
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)

self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)

# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture CUDA graph failed: {e}\n"
"Possible solutions:\n"
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)

def can_run(self, forward_batch: ForwardBatch):
batch_size = forward_batch.seq_lens.numel()

is_bs_supported = (
batch_size in self.graphs
if self.disable_padding
else batch_size <= self.max_bs
)

return is_bs_supported

def capture(self):
CudaGraphRunner.capture(self)

def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs

# Graph inputs
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
extend_seq_lens = self.extend_seq_lens[:bs]
accept_length = self.accept_length[:bs]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]

spec_info = EagleDraftInput(
hidden_states=hidden_states,
accept_length=accept_length,
)
spec_info.positions = None

# Forward batch
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DRAFT_EXTEND,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=CaptureHiddenMode.LAST,
attn_backend=self.eagle_worker.draft_extend_attn_backend,
extend_seq_lens=extend_seq_lens,
padded_static_len=self.padded_static_len,
)

self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_capture_cuda_graph(
bs=bs,
num_tokens=num_tokens,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DRAFT_EXTEND,
spec_info=spec_info,
)

# Run and capture
def run_once():
# Back up two fields that will be modified in place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states

ret = self.eagle_worker.draft_model_runner.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
)

forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret

for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()

run_once()

with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()

set_global_graph_memory_pool(graph.pool())
return graph, out

def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
# batch_size and num_seqs can differ when the batch contains finished requests,
# which are not counted in num_seqs
raw_bs = forward_batch.seq_lens.numel()
num_tokens = forward_batch.input_ids.numel()
assert raw_bs * self.num_tokens_per_bs == num_tokens

index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
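# The batch is padded up to a captured size; give the padded slots harmless
# defaults before the real inputs are copied in below.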
self.accept_length.fill_(1)
self.out_cache_loc.zero_()

# Common inputs
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens[:raw_bs])
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens[:raw_bs])
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
self.positions[:num_tokens].copy_(forward_batch.positions)
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
self.accept_length[:raw_bs].copy_(
forward_batch.spec_info.accept_length[:raw_bs]
)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

forward_batch.spec_info.positions = None
if bs != raw_bs:
forward_batch.spec_info.accept_length = self.accept_length[:bs]

self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
bs=bs,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
seq_lens_sum=sum(forward_batch.seq_lens),
encoder_lens=None,
forward_mode=ForwardMode.DRAFT_EXTEND,
spec_info=forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
)

# Replay
self.graphs[bs].replay()
out = self.output_buffers[bs]
if bs != raw_bs:
forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
out = LogitsProcessorOutput(
next_token_logits=out.next_token_logits[:raw_bs],
hidden_states=out.hidden_states[:raw_bs],
)
return out
45 changes: 45 additions & 0 deletions python/sglang/srt/speculative/eagle_utils.py
@@ -84,6 +84,7 @@ def prepare_extend_after_decode(
self,
batch: ScheduleBatch,
speculative_num_steps: int,
pad_input: bool = False,
):
assert len(self.verified_id) == len(batch.out_cache_loc)
accept_length_cpu = batch.spec_info.accept_length_cpu
@@ -111,6 +112,50 @@
batch.input_ids = self.verified_id
self.verified_id = new_verified_id

if pad_input:
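# Pad the batch so every request has the same input length, which lets the
# draft-extend pass be replayed with a CUDA graph.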
batch_size = sum(not req.finished() for req in batch.reqs)
# Total constant input length after padding
static_len = speculative_num_steps + 1
# Total size after padding
padded_input_size = batch_size * static_len

padded_len = padded_input_size - batch.input_ids.shape[0]
if padded_len > 0:
new_input_ids = torch.nn.functional.pad(
batch.input_ids, (0, padded_len), value=0
)
position_padding = torch.arange(
padded_len, device=self.positions.device
)
new_positions = torch.cat([self.positions, position_padding])

# need dummy hidden states for the padded positions
hidden_states_dim = self.hidden_states.shape[-1]
new_hidden_states = torch.cat(
[
self.hidden_states,
torch.zeros(
(padded_len, hidden_states_dim),
dtype=self.hidden_states.dtype,
device=self.hidden_states.device,
),
],
dim=0,
)

# allocate KV cache location for the padded tokens
padded_cache_loc = torch.zeros(
padded_len,
dtype=batch.out_cache_loc.dtype,
device=batch.out_cache_loc.device,
)
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])

batch.input_ids = new_input_ids
self.hidden_states = new_hidden_states
self.positions = new_positions
batch.out_cache_loc = new_out_cache_loc

def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,