Minor refactor two-batch overlap #6682
First changed file (the CUDA graph runner):

@@ -40,7 +40,7 @@
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import (
-    TboCudaGraphRunnerUtils,
+    TboCudaGraphRunnerPlugin,
     TboForwardBatchPreparer,
 )
 from sglang.srt.utils import (
@@ -256,6 +256,7 @@ def __init__(self, model_runner: ModelRunner):
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
+        self.tbo_plugin = TboCudaGraphRunnerPlugin()

         # pipeline parallelism
         if self.pp_size > 1:
@@ -481,12 +482,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
-            tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
-                self, num_tokens
-            ),
             global_forward_mode=self.capture_forward_mode,
         )
-        TboForwardBatchPreparer.prepare(forward_batch)
+        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

         if lora_paths is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
@@ -581,7 +579,13 @@ def replay_prepare(
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        self.num_token_non_padded[...] = len(forward_batch.input_ids)
+        num_token_non_padded = len(forward_batch.input_ids)
+        self.num_token_non_padded[...] = num_token_non_padded
+        self.tbo_plugin.replay_prepare(
+            forward_mode=forward_batch.forward_mode,
+            bs=bs,
+            num_token_non_padded=num_token_non_padded,
+        )
Comment on lines +582 to +588
Contributor:
This section introduces a call to self.tbo_plugin.replay_prepare(...).
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
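Taken together, the runner-side hunks wire the new plugin into three fixed points: construction, graph capture, and replay preparation. The toy sketch below is illustrative only (self-contained, all names invented; it is not sglang code) and shows that hook pattern: the runner owns a plugin object whose entry points degrade to no-ops when the feature flag is off.

from dataclasses import dataclass
from typing import Optional

# Stand-in for global_server_args_dict["enable_two_batch_overlap"].
ENABLE_TWO_BATCH_OVERLAP = True


@dataclass
class ToyBatch:
    num_tokens: int
    tbo_split_seq_index: Optional[int] = None


class ToyTboPlugin:
    def capture_one_batch_size(self, batch: ToyBatch, num_tokens: int) -> None:
        # All TBO-specific capture-time work lives behind this hook.
        if not ENABLE_TWO_BATCH_OVERLAP:
            return
        batch.tbo_split_seq_index = num_tokens // 2  # simplified split policy

    def replay_prepare(self, bs: int, num_token_non_padded: int) -> None:
        # Mirrors the PR: the replay hook exists but does nothing yet.
        if not ENABLE_TWO_BATCH_OVERLAP:
            return


class ToyGraphRunner:
    def __init__(self) -> None:
        self.tbo_plugin = ToyTboPlugin()  # construction hook

    def capture(self, num_tokens: int) -> ToyBatch:
        batch = ToyBatch(num_tokens=num_tokens)
        self.tbo_plugin.capture_one_batch_size(batch, num_tokens=num_tokens)
        return batch


print(ToyGraphRunner().capture(num_tokens=8).tbo_split_seq_index)  # -> 4

The design benefit shown in the toy is the same one the PR is after: the graph runner no longer needs to know how the split index is computed or whether TBO is enabled; it just calls the plugin at the agreed points.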
Second changed file (sglang.srt.two_batch_overlap):

@@ -85,25 +85,54 @@ def compute_split_token_index(
         raise NotImplementedError
+def compute_split_indices_for_cuda_graph_replay(
+    forward_mode: ForwardMode,
+    cuda_graph_num_tokens: int,
+):
+    forward_mode_for_tbo_split = (
+        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
+    )
+    tbo_split_seq_index = compute_split_seq_index(
+        forward_mode=forward_mode_for_tbo_split,
+        num_tokens=cuda_graph_num_tokens,
+        extend_lens=None,
+    )
+    tbo_split_token_index = compute_split_token_index(
+        split_seq_index=tbo_split_seq_index,
+        forward_mode=forward_mode_for_tbo_split,
+        extend_seq_lens=None,
+    )
+    return tbo_split_seq_index, tbo_split_token_index
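A minimal usage sketch of the new helper follows. The module name comes from the import hunk above; the ForwardMode import path and the concrete numbers are assumptions for illustration, not part of the PR.

# Sketch only: exercising the new helper in isolation.
# Assumes sglang is installed; the ForwardMode import path is an assumption.
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import compute_split_indices_for_cuda_graph_replay

# IDLE is remapped to DECODE internally before the split is computed.
seq_idx, tok_idx = compute_split_indices_for_cuda_graph_replay(
    forward_mode=ForwardMode.IDLE,
    cuda_graph_num_tokens=160,
)
# The exact values depend on compute_split_seq_index's policy; for decode-style
# batches the split is expected to land near the middle of the 160 tokens.
print(seq_idx, tok_idx)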
 # -------------------------------- Preparation ---------------------------------------
-class TboCudaGraphRunnerUtils:
-    @staticmethod
-    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
-        if that.model_runner.server_args.enable_two_batch_overlap:
-            tbo_split_seq_index = compute_split_seq_index(
-                forward_mode=that.capture_forward_mode,
-                num_tokens=num_tokens,
-                extend_lens=None,
-            )
-            # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
-            assert (
-                tbo_split_seq_index is not None
-            ), f"{that.capture_forward_mode=} {num_tokens=}"
-        else:
-            tbo_split_seq_index = None
-        return tbo_split_seq_index
+class TboCudaGraphRunnerPlugin:
+    def __init__(self):
+        pass # TODO add logic here
Comment on lines +112 to +113
Contributor:
The __init__ of TboCudaGraphRunnerPlugin currently only contains pass # TODO add logic here. Could you clarify if there's any initialization logic planned for this plugin in the near future, or if this empty constructor is intentional for now?
+    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
+        if not global_server_args_dict["enable_two_batch_overlap"]:
+            return
+
+        batch.tbo_split_seq_index = compute_split_seq_index(
+            forward_mode=batch.forward_mode,
+            num_tokens=num_tokens,
+            extend_lens=None,
+        )
+        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
+        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
+
+        TboForwardBatchPreparer.prepare(batch)
+    def replay_prepare(
+        self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
+    ):
+        if not global_server_args_dict["enable_two_batch_overlap"]:
+            return
+
+        pass # TODO add logic here
Contributor:
The replay_prepare method currently only contains the early return and pass # TODO add logic here. What is the intended purpose of this method? Is there specific TBO-related preparation logic that needs to be executed during CUDA graph replay that is planned to be added here? If so, this PR might be setting up the structure for it. Clarifying the scope of this TODO would help.
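One plausible direction, given that compute_split_indices_for_cuda_graph_replay is added in the same file, is for replay_prepare to recompute the split indices for the padded replay batch. The body below is purely a hypothetical sketch of such logic (the cuda_graph_num_tokens=bs choice and the stashed attributes are invented); it is not part of this PR.

    def replay_prepare(
        self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
    ):
        if not global_server_args_dict["enable_two_batch_overlap"]:
            return

        # Hypothetical: for decode-shaped CUDA graphs the captured graph has one
        # token per sequence, so the padded token count equals bs.
        tbo_split_seq_index, tbo_split_token_index = (
            compute_split_indices_for_cuda_graph_replay(
                forward_mode=forward_mode,
                cuda_graph_num_tokens=bs,
            )
        )
        # Hypothetical attributes: stash the indices so the replayed ForwardBatch
        # can be patched before the graph is launched.
        self.tbo_split_seq_index = tbo_split_seq_index
        self.tbo_split_token_index = tbo_split_token_index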
 class TboDPAttentionPreparer:
The original code computes tbo_split_seq_index here. Now that this logic is moved to TboCudaGraphRunnerPlugin, is it correct to remove it from here? Please confirm that this change aligns with the intended behavior.