Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions python/sglang/srt/layers/attention/tbo_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,27 @@ def _init_forward_metadata_cuda_graph_children(
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
)
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
num_tokens = bs
assert (
capture_num_tokens == bs * token_num_per_seq
), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
num_tokens = bs * token_num_per_seq

tbo_split_seq_index, tbo_split_token_index = (
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode,
cuda_graph_num_tokens=num_tokens,
spec_info=spec_info,
)
)

num_tokens_child_left = tbo_split_token_index
num_tokens_child_right = num_tokens - tbo_split_token_index
bs_child_left = num_tokens_child_left
bs_child_right = num_tokens_child_right
bs_child_left = tbo_split_seq_index
bs_child_right = bs - bs_child_left

assert (
num_tokens_child_left > 0 and num_tokens_child_right > 0
Expand Down Expand Up @@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[EagleVerifyInput],
# capture args
capture_num_tokens: int = None,
# replay args
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info
)
assert encoder_lens is None, "encoder_lens is not supported yet"
assert spec_info is None, "spec_info is not supported yet"
if spec_info is not None:
output_spec_info = two_batch_overlap.split_spec_info(
spec_info=spec_info,
start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
start_token_index=(
seq_slice.start * token_num_per_seq
if seq_slice.start is not None
else 0
),
end_token_index=(
seq_slice.stop * token_num_per_seq
if seq_slice.stop is not None
else bs * token_num_per_seq
),
)

else:
output_spec_info = None
ans = dict(
bs=output_bs,
req_pool_indices=req_pool_indices[seq_slice],
Expand All @@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
forward_mode=forward_mode,
# ignore
encoder_lens=None,
spec_info=None,
spec_info=output_spec_info,
)

if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
assert (
capture_num_tokens == bs * token_num_per_seq
), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
ans.update(
dict(
num_tokens=output_bs,
num_tokens=output_bs * token_num_per_seq,
)
)
elif fn_name == "init_forward_metadata_replay_cuda_graph":
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def replay_prepare(
forward_mode=self.capture_forward_mode,
bs=bs,
num_token_non_padded=len(forward_batch.input_ids),
spec_info=forward_batch.spec_info,
)
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
forward_batch.spec_info.custom_mask = self.custom_mask
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,9 @@ def init_new(

if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(ret)
TboForwardBatchPreparer.prepare(
ret, is_draft_worker=model_runner.is_draft_worker
)
return ret

# Override the positions with spec_info
Expand Down Expand Up @@ -397,7 +399,9 @@ def init_new(
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)

TboForwardBatchPreparer.prepare(ret)
TboForwardBatchPreparer.prepare(
ret, is_draft_worker=model_runner.is_draft_worker
)

return ret

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def init_cublas(self):

def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.enable_two_batch_overlap:
if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
else:
self.attn_backend = self._get_attention_backend()
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/operations_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_deepseek_blog_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
elif (
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
):
return _compute_moe_deepseek_blog_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
Expand Down Expand Up @@ -146,7 +148,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_qwen3_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
elif (
forward_mode == ForwardMode.DECODE or forward_mode == ForwardMode.TARGET_VERIFY
):
return _compute_moe_qwen3_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
Expand Down
Loading
Loading