Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 203 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import logging
import os
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch
import tqdm
Expand Down Expand Up @@ -58,9 +59,10 @@
ForwardBatch,
ForwardMode,
PPProxyTensors,
compute_local_num_token_non_padded,
enable_num_token_non_padded,
)
from sglang.srt.model_executor.input_buffers import GraphInputBuffers
from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
from sglang.srt.multiplex.pdmux_context import get_current_stream_idx, get_stream_groups
from sglang.srt.utils import (
empty_context,
Expand Down Expand Up @@ -90,6 +92,200 @@
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner


@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
    """Pre-allocated input buffers for decode-mode CUDA graph capture/replay.

    Every tensor is allocated once at the maximum capture size
    (``max_bs`` requests / ``max_num_token`` tokens) and then overwritten
    in place on each replay, so that the device addresses baked into the
    captured CUDA graph stay valid.  Optional buffers
    (``mamba_track_*``, ``encoder_lens``, ``pp_proxy_tensors``) are
    ``None`` when the corresponding feature is disabled.
    """

    # Padded token ids; shape (max_num_token,), int64.
    input_ids: torch.Tensor
    # Padded input embeddings; shape (max_num_token, hidden_size), model dtype.
    input_embeds: torch.Tensor
    # KV-cache request-pool indices; shape (max_bs,), int32.
    req_pool_indices: torch.Tensor
    # Per-request sequence lengths; shape (max_bs,), int32,
    # initialized to seq_len_fill_value.
    seq_lens: torch.Tensor
    # CPU-resident mirror of seq_lens (kept on "cpu" on purpose; see create()).
    seq_lens_cpu: torch.Tensor
    # Output KV-cache slot locations; shape (max_num_token,), cache_loc_dtype.
    out_cache_loc: torch.Tensor
    # Token positions; shape (max_num_token,), int64.
    positions: torch.Tensor
    # M-RoPE positions; shape (3, max_num_token), int64.
    mrope_positions: torch.Tensor
    # Scalar count of non-padding tokens; shape (1,), int32.
    num_token_non_padded: torch.Tensor
    # Boolean attention mask buffer (all-ones at creation); 1-D, bool.
    custom_mask: torch.Tensor
    # Buffer for next-token logits; shape (max_num_token, vocab_size), float32.
    next_token_logits_buffer: torch.Tensor
    # Mamba state-tracking indices/mask; only allocated when
    # enable_mamba_track is set, otherwise None. Shapes (max_bs,).
    mamba_track_indices: Optional[torch.Tensor]
    mamba_track_mask: Optional[torch.Tensor]
    # Global (padded) token counts per DP rank; shape (dp_size,) when
    # require_mlp_tp_gather, else (1,). int32.
    global_num_tokens_gpu: torch.Tensor
    global_num_tokens_for_logprob_gpu: torch.Tensor
    # Encoder sequence lengths; only for encoder-decoder models, else None.
    encoder_lens: Optional[torch.Tensor]
    # Pipeline-parallel proxy buffers ("hidden_states"/"residual"), each
    # shape (max_bs, hidden_size); only when pp_size > 1, else None.
    pp_proxy_tensors: Optional[Dict[str, torch.Tensor]]

    @classmethod
    def create(
        cls,
        *,
        device: torch.device,
        max_bs: int,
        max_num_token: int,
        hidden_size: int,
        vocab_size: int,
        dtype: torch.dtype,
        dp_size: int,
        pp_size: int,
        is_encoder_decoder: bool,
        require_mlp_tp_gather: bool,
        seq_len_fill_value: int,
        encoder_len_fill_value: int,
        num_tokens_per_bs: int,
        cache_loc_dtype: torch.dtype,
        enable_mamba_track: bool,
    ) -> "DecodeInputBuffers":
        """Allocate all buffers at maximum capture size on ``device``.

        Args:
            device: Target device; all tensors except ``seq_lens_cpu`` are
                allocated here via the ``torch.device`` context manager.
            max_bs: Maximum batch size any captured graph will use.
            max_num_token: Maximum padded token count (max_bs * num_tokens_per_bs).
            hidden_size / vocab_size / dtype: Model dimensions and dtype.
            dp_size / pp_size: Data- and pipeline-parallel world sizes.
            is_encoder_decoder: Allocate ``encoder_lens`` when True.
            require_mlp_tp_gather: Size the global-token-count buffers per
                DP rank instead of as scalars.
            seq_len_fill_value / encoder_len_fill_value: Initial fill values
                for the length buffers (padding entries keep these values).
            num_tokens_per_bs: Tokens per request (e.g. >1 for spec decode).
            cache_loc_dtype: Integer dtype of KV-cache locations.
            enable_mamba_track: Allocate the mamba tracking buffers when True.

        Returns:
            A fully allocated ``DecodeInputBuffers`` instance.
        """
        with torch.device(device):
            input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
            input_embeds = torch.zeros((max_num_token, hidden_size), dtype=dtype)
            req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
            seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
            out_cache_loc = torch.zeros((max_num_token,), dtype=cache_loc_dtype)
            positions = torch.zeros((max_num_token,), dtype=torch.int64)
            mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)
            num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
            # NOTE(review): this sizing formula multiplies by
            # seq_len_fill_value, which elsewhere is used as a *fill value*
            # for seq_lens, not a length — confirm this matches the intended
            # mask capacity for the attention backend.
            custom_mask = torch.ones(
                (max_bs * seq_len_fill_value + max_num_token) * num_tokens_per_bs,
                dtype=torch.bool,
            )
            next_token_logits_buffer = torch.zeros(
                (max_num_token, vocab_size),
                dtype=torch.float,
            )
            mamba_track_indices = (
                torch.zeros((max_bs,), dtype=torch.int64)
                if enable_mamba_track
                else None
            )
            mamba_track_mask = (
                torch.zeros((max_bs,), dtype=torch.bool) if enable_mamba_track else None
            )

            # Pipeline parallelism: buffers for hidden states handed across
            # PP stage boundaries.
            if pp_size > 1:
                pp_proxy_tensors = {
                    "hidden_states": torch.zeros((max_bs, hidden_size), dtype=dtype),
                    "residual": torch.zeros((max_bs, hidden_size), dtype=dtype),
                }
            else:
                pp_proxy_tensors = None

            if is_encoder_decoder:
                encoder_lens = torch.full(
                    (max_bs,), encoder_len_fill_value, dtype=torch.int32
                )
            else:
                encoder_lens = None

            # With MLP TP gather we need one token count per DP rank;
            # otherwise a single scalar suffices.
            if require_mlp_tp_gather:
                global_num_tokens_gpu = torch.zeros((dp_size,), dtype=torch.int32)
                global_num_tokens_for_logprob_gpu = torch.zeros(
                    (dp_size,), dtype=torch.int32
                )
            else:
                global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                global_num_tokens_for_logprob_gpu = torch.zeros((1,), dtype=torch.int32)

        # Keep seq_lens_cpu as a true CPU tensor, like the old implementation.
        seq_lens_cpu = torch.full(
            (max_bs,),
            seq_len_fill_value,
            dtype=torch.int32,
            device="cpu",
        )

        return cls(
            input_ids=input_ids,
            input_embeds=input_embeds,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens_cpu,
            out_cache_loc=out_cache_loc,
            positions=positions,
            mrope_positions=mrope_positions,
            num_token_non_padded=num_token_non_padded,
            custom_mask=custom_mask,
            next_token_logits_buffer=next_token_logits_buffer,
            mamba_track_indices=mamba_track_indices,
            mamba_track_mask=mamba_track_mask,
            encoder_lens=encoder_lens,
            global_num_tokens_gpu=global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu,
            pp_proxy_tensors=pp_proxy_tensors,
        )

    def populate_from_forward_batch(
        self,
        *,
        forward_batch: ForwardBatch,
        raw_bs: int,
        raw_num_token: int,
        bs: int,
        seq_len_fill_value: int,
        require_gathered_buffer: bool,
        num_tokens_per_bs: int,
        nsa_enable_prefill_cp: bool,
        enable_num_token_non_padded_flag: bool,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> None:
        """Copy a live ``forward_batch`` into the captured graph's buffers.

        Writes the first ``raw_bs`` request entries / ``raw_num_token``
        token entries of each buffer in place; when the batch is padded up
        to a larger captured size (``bs != raw_bs``) the padding region is
        reset first so stale values from a previous replay cannot leak in.

        Args:
            forward_batch: Source batch whose tensors are copied in.
            raw_bs / raw_num_token: Actual (unpadded) request/token counts.
            bs: Captured (padded) batch size chosen for replay.
            seq_len_fill_value: Value written into padded seq_lens entries.
            require_gathered_buffer: Fill the global token-count buffers
                with the padded counts when True.
            num_tokens_per_bs: Tokens per request at this capture size.
            nsa_enable_prefill_cp: Disables the per-DP-rank non-padded-token
                computation below when set.
            enable_num_token_non_padded_flag: Update num_token_non_padded
                when True.
            pp_proxy_tensors: Optional incoming PP tensors to stage into
                the pp proxy buffers.
        """
        # Reset padding region when replaying at a larger captured size.
        if bs != raw_bs:
            self.seq_lens.fill_(seq_len_fill_value)
            self.out_cache_loc.zero_()
            if self.mamba_track_indices is not None:
                self.mamba_track_indices.zero_()
            if self.mamba_track_mask is not None:
                self.mamba_track_mask.fill_(False)

        # Common inputs
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)

        if (
            self.mamba_track_indices is not None
            and forward_batch.mamba_track_indices is not None
        ):
            self.mamba_track_indices[:raw_bs].copy_(forward_batch.mamba_track_indices)
        if (
            self.mamba_track_mask is not None
            and forward_batch.mamba_track_mask is not None
        ):
            self.mamba_track_mask[:raw_bs].copy_(forward_batch.mamba_track_mask)

        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

        if self.encoder_lens is not None and forward_batch.encoder_lens is not None:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)

        # The captured graph always runs with the padded token count, so the
        # global counters are filled with bs * num_tokens_per_bs, not the raw
        # counts.
        if require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(bs * num_tokens_per_bs)
            self.global_num_tokens_for_logprob_gpu.fill_(bs * num_tokens_per_bs)

        if enable_num_token_non_padded_flag:
            if require_gathered_buffer and not nsa_enable_prefill_cp:
                # Convert the global non-padded token count into this DP
                # rank's local count.
                num_tokens_per_dp = bs * num_tokens_per_bs
                local = compute_local_num_token_non_padded(
                    global_num_token_non_padded=forward_batch.num_token_non_padded,
                    num_tokens_per_dp=num_tokens_per_dp,
                )
                self.num_token_non_padded.copy_(local)
            else:
                self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)

        # Pipeline-parallel proxy tensors.
        if pp_proxy_tensors is not None and self.pp_proxy_tensors is not None:
            # Copy only the leading rows the source provides; the rest of the
            # buffer keeps its previous (padding) contents.
            for key, buf in self.pp_proxy_tensors.items():
                src = pp_proxy_tensors.tensors[key]
                dim = src.shape[0]
                buf[:dim].copy_(src)


# Module-level flag: True while a CUDA graph forward pass is being captured,
# so model code can detect capture mode. Presumably toggled by a capture
# context manager elsewhere in this file (not visible in this chunk).
is_capture_mode = False

Expand Down Expand Up @@ -337,7 +533,7 @@ def __init__(self, model_runner: ModelRunner):

if self.require_gathered_buffer:
assert self.require_mlp_tp_gather or self.require_attn_tp_gather
self.buffers: GraphInputBuffers = GraphInputBuffers.create(
self.buffers: DecodeInputBuffers = DecodeInputBuffers.create(
device=self.device,
max_bs=self.max_bs,
max_num_token=self.max_num_token,
Expand All @@ -354,6 +550,7 @@ def __init__(self, model_runner: ModelRunner):
cache_loc_dtype=self._cache_loc_dtype(),
enable_mamba_track=enable_mamba_track,
)
self.buffers.share_buffers()

self.tbo_plugin = TboCudaGraphRunnerPlugin()

Expand Down Expand Up @@ -556,7 +753,7 @@ def _create_device_graph(self):
def capture_one_batch_size(
self, bs: int, forward: Callable, stream_idx: Optional[int] = None
):
buffers: GraphInputBuffers = self.buffers
buffers: DecodeInputBuffers = self.buffers
graph = self._create_device_graph()
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs
Expand Down Expand Up @@ -798,7 +995,7 @@ def replay_prepare(
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]

seq_lens_cpu = buffers.populate_from_forward_batch(
buffers.populate_from_forward_batch(
forward_batch=forward_batch,
raw_bs=raw_bs,
raw_num_token=raw_num_token,
Expand Down Expand Up @@ -835,7 +1032,7 @@ def replay_prepare(
buffers.encoder_lens[:bs] if self.is_encoder_decoder else None,
self.capture_forward_mode,
forward_batch.spec_info,
seq_lens_cpu=seq_lens_cpu,
seq_lens_cpu=buffers.seq_lens_cpu[:bs],
)

# Store fields
Expand Down
Loading
Loading