Merged
19 changes: 5 additions & 14 deletions python/sglang/srt/layers/attention/tbo_backend.py
@@ -119,24 +119,15 @@ def _init_forward_metadata_cuda_graph_children(
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
from sglang.srt.model_executor.forward_batch_info import ForwardMode

if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
num_tokens = bs

forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
)
tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
forward_mode=forward_mode_for_tbo_split,
num_tokens=num_tokens,
extend_lens=None,
)
tbo_split_token_index = two_batch_overlap.compute_split_token_index(
split_seq_index=tbo_split_seq_index,
forward_mode=forward_mode_for_tbo_split,
extend_seq_lens=None,
tbo_split_seq_index, tbo_split_token_index = (
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode,
cuda_graph_num_tokens=num_tokens,
)
)

num_tokens_child_left = tbo_split_token_index
16 changes: 10 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -40,7 +40,7 @@
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import (
TboCudaGraphRunnerUtils,
TboCudaGraphRunnerPlugin,
TboForwardBatchPreparer,
)
from sglang.srt.utils import (
@@ -256,6 +256,7 @@ def __init__(self, model_runner: ModelRunner):
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
self.tbo_plugin = TboCudaGraphRunnerPlugin()

# pipeline parallelism
if self.pp_size > 1:
@@ -481,12 +482,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
num_token_non_padded=self.num_token_non_padded,
tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
self, num_tokens
),
global_forward_mode=self.capture_forward_mode,
Comment on lines 484 to 487 (Contributor, severity: medium):
The original code computes tbo_split_seq_index here. Now that this logic is moved to TboCudaGraphRunnerPlugin, is it correct to remove it from here? Please confirm that this change aligns with the intended behavior.
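
For reference, a simplified, self-contained sketch (not sglang code) of the refactor this comment asks about: the split index is no longer passed into the ForwardBatch constructor; instead the plugin computes the same value and writes it onto the batch before TboForwardBatchPreparer.prepare runs. The toy compute_split_seq_index, Batch, and Plugin below are stand-ins, not the real sglang helpers.

from dataclasses import dataclass
from typing import Optional


def compute_split_seq_index(num_tokens: int) -> Optional[int]:
    # Stand-in for two_batch_overlap.compute_split_seq_index (decode case: split in half).
    return num_tokens // 2 if num_tokens >= 2 else None


@dataclass
class Batch:
    num_tokens: int
    tbo_split_seq_index: Optional[int] = None


class Plugin:
    def capture_one_batch_size(self, batch: Batch, num_tokens: int) -> None:
        # New style: the plugin fills the field after the batch is constructed.
        batch.tbo_split_seq_index = compute_split_seq_index(num_tokens)
        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"


# Old style: the index was computed up front and passed to the constructor.
old_batch = Batch(num_tokens=8, tbo_split_seq_index=compute_split_seq_index(8))

# New style: construct first, then let the plugin populate the field.
new_batch = Batch(num_tokens=8)
Plugin().capture_one_batch_size(new_batch, num_tokens=8)

assert old_batch.tbo_split_seq_index == new_batch.tbo_split_seq_index  # same end state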

)
TboForwardBatchPreparer.prepare(forward_batch)
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
@@ -581,7 +579,13 @@ def replay_prepare(
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
self.num_token_non_padded[...] = len(forward_batch.input_ids)
num_token_non_padded = len(forward_batch.input_ids)
self.num_token_non_padded[...] = num_token_non_padded
self.tbo_plugin.replay_prepare(
forward_mode=forward_batch.forward_mode,
bs=bs,
num_token_non_padded=num_token_non_padded,
)
Comment on lines +582 to +588 (Contributor, severity: medium):
This section introduces a call to self.tbo_plugin.replay_prepare. Please double-check that the arguments (forward_mode, bs, num_token_non_padded) are the values the TBO logic expects during replay preparation.
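
As a hedged aid for that check, the sketch below (not part of the PR) spells out how the three arguments relate to the batch being replayed; ForwardMode.is_decode() and the one-token-per-request property of plain decode are assumptions about sglang rather than something this diff shows.

def check_tbo_replay_args(forward_batch, bs: int) -> None:
    # Illustrative sanity check only; mirrors the values used in replay_prepare above.
    raw_bs = len(forward_batch.seq_lens)                 # un-padded number of requests
    num_token_non_padded = len(forward_batch.input_ids)  # value handed to tbo_plugin.replay_prepare

    # For plain decode, each request contributes one token, so the non-padded
    # token count should match the un-padded batch size (assumption).
    if forward_batch.forward_mode.is_decode():
        assert num_token_non_padded == raw_bs

    # bs is the padded batch size the CUDA graph buffers were captured with,
    # so it should be at least the un-padded batch size.
    assert bs >= raw_bs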

if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
61 changes: 45 additions & 16 deletions python/sglang/srt/two_batch_overlap.py
@@ -85,25 +85,54 @@ def compute_split_token_index(
raise NotImplementedError


def compute_split_indices_for_cuda_graph_replay(
forward_mode: ForwardMode,
cuda_graph_num_tokens: int,
):
forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
)
tbo_split_seq_index = compute_split_seq_index(
forward_mode=forward_mode_for_tbo_split,
num_tokens=cuda_graph_num_tokens,
extend_lens=None,
)
tbo_split_token_index = compute_split_token_index(
split_seq_index=tbo_split_seq_index,
forward_mode=forward_mode_for_tbo_split,
extend_seq_lens=None,
)
return tbo_split_seq_index, tbo_split_token_index


# -------------------------------- Preparation ---------------------------------------


class TboCudaGraphRunnerUtils:
@staticmethod
def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
if that.model_runner.server_args.enable_two_batch_overlap:
tbo_split_seq_index = compute_split_seq_index(
forward_mode=that.capture_forward_mode,
num_tokens=num_tokens,
extend_lens=None,
)
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
assert (
tbo_split_seq_index is not None
), f"{that.capture_forward_mode=} {num_tokens=}"
else:
tbo_split_seq_index = None
return tbo_split_seq_index
class TboCudaGraphRunnerPlugin:
def __init__(self):
pass # TODO add logic here
Comment on lines +112 to +113 (Contributor, severity: medium):
The __init__ method of the new TboCudaGraphRunnerPlugin contains only a pass statement with a TODO add logic here comment.

Is initialization logic planned for this plugin, or is the TODO a placeholder for future state management? If no initialization is needed right now, that is fine, but clarifying the intent would be helpful.


def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]:
return

batch.tbo_split_seq_index = compute_split_seq_index(
forward_mode=batch.forward_mode,
num_tokens=num_tokens,
extend_lens=None,
)
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

TboForwardBatchPreparer.prepare(batch)

def replay_prepare(
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
):
if not global_server_args_dict["enable_two_batch_overlap"]:
return

pass # TODO add logic here
Contributor comment (severity: medium):
The new replay_prepare method in TboCudaGraphRunnerPlugin is hooked into CudaGraphRunner.replay_prepare, but after the enable_two_batch_overlap check it contains only a pass statement with a TODO add logic here comment.

What is the intended purpose of this method? Is there TBO-specific preparation logic that needs to run during CUDA graph replay and is planned to land here? If so, this PR is presumably just setting up the structure; clarifying the scope of the TODO would help.
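
Purely as a hypothetical sketch of what might eventually fill this TODO (nothing in the PR confirms it): the replay hook could mirror what the attention backend now does during replay and recompute the split indices from the padded graph size via the compute_split_indices_for_cuda_graph_replay helper added in this change. The method body below assumes the names already in scope in two_batch_overlap.py and is not part of this PR.

def replay_prepare(
    self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
):
    if not global_server_args_dict["enable_two_batch_overlap"]:
        return

    # Hypothetical: decode CUDA graphs are captured with num_tokens == bs,
    # so the padded bs would play the role of cuda_graph_num_tokens here.
    tbo_split_seq_index, tbo_split_token_index = (
        compute_split_indices_for_cuda_graph_replay(
            forward_mode=forward_mode,
            cuda_graph_num_tokens=bs,
        )
    )
    # How these indices and num_token_non_padded would then be applied to
    # the captured child batches is exactly what the TODO leaves open.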



class TboDPAttentionPreparer: