Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions python/sglang/srt/layers/attention/tbo_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,40 @@ def _init_forward_metadata_cuda_graph_children(
replay_seq_lens_sum: int = None,
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
num_tokens = bs
if not forward_mode.is_target_verify():
if fn_name == "init_forward_metadata_capture_cuda_graph":
assert (
capture_num_tokens == bs
), "Only support num_tokens==bs currently unless target-verify mode"
num_tokens = bs
else:
draft_token_num = spec_info.draft_token_num
if capture_num_tokens is not None:
assert (
draft_token_num * bs == capture_num_tokens
), f"For target-verify mode, num_tokens ({capture_num_tokens}) should be equal to draft_token_num ({draft_token_num}) * bs ({bs})"
num_tokens = draft_token_num * bs

tbo_split_seq_index, tbo_split_token_index = (
two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode,
cuda_graph_num_tokens=num_tokens,
spec_info=spec_info,
)
)

num_tokens_child_left = tbo_split_token_index
num_tokens_child_right = num_tokens - tbo_split_token_index
bs_child_left = num_tokens_child_left
bs_child_right = num_tokens_child_right
bs_child_left = (
num_tokens_child_left
if not forward_mode.is_target_verify()
else tbo_split_seq_index
)
bs_child_right = (
num_tokens_child_right
if not forward_mode.is_target_verify()
else bs - tbo_split_seq_index
)

assert (
num_tokens_child_left > 0 and num_tokens_child_right > 0
Expand Down Expand Up @@ -198,7 +217,6 @@ def _init_forward_metadata_cuda_graph_split(
replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
assert encoder_lens is None, "encoder_lens is not supported yet"
assert spec_info is None, "spec_info is not supported yet"

ans = dict(
bs=output_bs,
Expand All @@ -208,16 +226,28 @@ def _init_forward_metadata_cuda_graph_split(
forward_mode=forward_mode,
# ignore
encoder_lens=None,
spec_info=None,
spec_info=spec_info,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe split spec_info into two micro-batches here (one per TBO child) instead of passing the full spec_info to both children?

)

if fn_name == "init_forward_metadata_capture_cuda_graph":
assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
ans.update(
dict(
num_tokens=output_bs,
if forward_mode.is_target_verify():
assert (
capture_num_tokens == bs * spec_info.draft_token_num
), "For target-verify mode, num_tokens should be equal to draft_token_num * bs"
ans.update(
dict(
num_tokens=capture_num_tokens,
)
)
else:
assert (
capture_num_tokens == bs
), "Only support num_tokens==bs currently unless target-verify mode"
ans.update(
dict(
num_tokens=output_bs,
)
)
)
elif fn_name == "init_forward_metadata_replay_cuda_graph":
output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
ans.update(
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ def replay_prepare(
forward_mode=forward_batch.forward_mode,
bs=bs,
num_token_non_padded=len(forward_batch.input_ids),
spec_info=forward_batch.spec_info,
)

# Attention backend
Expand Down
10 changes: 8 additions & 2 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,10 @@ def init_new(

if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
TboForwardBatchPreparer.prepare(ret)
if not model_runner.is_draft_worker:
TboForwardBatchPreparer.prepare(ret)
else:
ret.tbo_split_seq_index = None
return ret

# Override the positions with spec_info
Expand Down Expand Up @@ -397,7 +400,10 @@ def init_new(
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)

TboForwardBatchPreparer.prepare(ret)
if not model_runner.is_draft_worker:
TboForwardBatchPreparer.prepare(ret)
else:
ret.tbo_split_seq_index = None

return ret

Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/operations_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
return _compute_moe_deepseek_blog_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
return _compute_moe_deepseek_blog_decode(layer)
elif forward_mode == ForwardMode.TARGET_VERIFY:
return _compute_moe_deepseek_blog_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")

Expand Down Expand Up @@ -148,6 +150,8 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
return _compute_moe_qwen3_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
return _compute_moe_qwen3_decode(layer)
elif forward_mode == ForwardMode.TARGET_VERIFY:
return _compute_moe_qwen3_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")

Expand Down
91 changes: 75 additions & 16 deletions python/sglang/srt/two_batch_overlap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import logging
from typing import Dict, List, Optional, Sequence
from typing import Dict, List, Optional, Sequence, Union

import torch

Expand All @@ -16,6 +16,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var

_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
Expand All @@ -31,7 +32,11 @@ def compute_split_seq_index(
forward_mode: "ForwardMode",
num_tokens: int,
extend_lens: Optional[Sequence[int]],
draft_token_num_per_batch: Optional[int],
) -> Optional[int]:
if forward_mode.is_target_verify():
assert draft_token_num_per_batch is not None
return (num_tokens // draft_token_num_per_batch) // 2
if forward_mode.is_extend():
assert extend_lens is not None
return _split_array_by_half_sum(extend_lens)
Expand Down Expand Up @@ -67,8 +72,12 @@ def compute_split_token_index(
split_seq_index: int,
forward_mode: "ForwardMode",
extend_seq_lens: Optional[Sequence[int]],
draft_token_num_per_batch: Optional[int],
) -> int:
if forward_mode.is_extend():
if forward_mode.is_target_verify():
assert draft_token_num_per_batch is not None
return split_seq_index * draft_token_num_per_batch
elif forward_mode.is_extend():
assert extend_seq_lens is not None
return sum(extend_seq_lens[:split_seq_index])
elif forward_mode.is_decode():
Expand All @@ -83,19 +92,26 @@ def compute_split_token_index(
def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode,
    cuda_graph_num_tokens: int,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Compute the two-batch-overlap split point for a CUDA-graph batch.

    Args:
        forward_mode: Mode of the batch. ``ForwardMode.IDLE`` is split as if
            it were ``ForwardMode.DECODE``.
        cuda_graph_num_tokens: Number of tokens in the CUDA-graph batch.
        spec_info: Speculative-decoding input. Only its ``draft_token_num``
            is read, and only when ``forward_mode.is_target_verify()``.

    Returns:
        A ``(tbo_split_seq_index, tbo_split_token_index)`` tuple.
    """
    # IDLE has no split rule of its own, so it borrows the DECODE one.
    if forward_mode != ForwardMode.IDLE:
        split_mode = forward_mode
    else:
        split_mode = ForwardMode.DECODE

    # The target-verify check is made on the original mode, not the remapped one.
    draft_token_num_per_batch = (
        spec_info.draft_token_num if forward_mode.is_target_verify() else None
    )

    seq_index = compute_split_seq_index(
        forward_mode=split_mode,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
        draft_token_num_per_batch=draft_token_num_per_batch,
    )
    token_index = compute_split_token_index(
        split_seq_index=seq_index,
        forward_mode=split_mode,
        extend_seq_lens=None,
        draft_token_num_per_batch=draft_token_num_per_batch,
    )
    return seq_index, token_index

Expand All @@ -110,11 +126,16 @@ def __init__(self):
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]:
return
if batch.forward_mode.is_target_verify():
draft_token_num_per_batch = batch.spec_info.draft_token_num
else:
draft_token_num_per_batch = None

batch.tbo_split_seq_index = compute_split_seq_index(
forward_mode=batch.forward_mode,
num_tokens=num_tokens,
extend_lens=None,
draft_token_num_per_batch=draft_token_num_per_batch,
)
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
Expand All @@ -129,13 +150,21 @@ def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
)

def replay_prepare(
self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
self,
forward_mode: ForwardMode,
bs: int,
num_token_non_padded: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_target_verify():
token_num_per_batch = spec_info.draft_token_num
else:
token_num_per_batch = 1
tbo_split_seq_index, tbo_split_token_index = (
compute_split_indices_for_cuda_graph_replay(
forward_mode=forward_mode,
# TODO support bs!=num_tokens
cuda_graph_num_tokens=bs,
cuda_graph_num_tokens=bs * token_num_per_batch,
spec_info=spec_info,
)
)

Expand All @@ -154,14 +183,27 @@ def prepare_all_gather(
self.enable_two_batch_overlap = enable_two_batch_overlap

if local_batch is not None:
if local_batch.forward_mode.is_target_verify():
draft_token_num_per_batch = local_batch.spec_info.draft_token_num
num_tokens = local_batch.batch_size() * draft_token_num_per_batch
else:
draft_token_num_per_batch = None
if local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens
self.local_tbo_split_seq_index = compute_split_seq_index(
forward_mode=local_batch.forward_mode,
num_tokens=local_batch.input_ids.shape[0],
num_tokens=num_tokens,
extend_lens=local_batch.extend_lens,
draft_token_num_per_batch=draft_token_num_per_batch,
)
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
local_batch.forward_mode.is_extend()
(
local_batch.forward_mode.is_extend()
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.low_latency)
)
Expand Down Expand Up @@ -242,7 +284,9 @@ def prepare_raw(
f"TboForwardBatchPreparer.prepare "
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
f"tbo_split_token_index={tbo_split_token_index} "
f"extend_seq_lens={batch.extend_seq_lens_cpu}"
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
f"bs={batch.batch_size} "
f"forward_mode={batch.forward_mode}"
)

assert isinstance(batch.attn_backend, TboAttnBackend)
Expand Down Expand Up @@ -286,6 +330,9 @@ def filter_batch(
output_attn_backend: AttentionBackend,
out_num_token_non_padded: torch.Tensor,
):
assert (
end_token_index >= start_token_index
), f"{end_token_index=}, {start_token_index=}, batch={batch}"
num_tokens = batch.input_ids.shape[0]
num_seqs = batch.batch_size

Expand Down Expand Up @@ -317,6 +364,9 @@ def filter_batch(
old_value = getattr(batch, key)
if old_value is None:
continue
elif batch.forward_mode.is_target_verify() and key.startswith("extend_"):
output_dict[key] = None
continue
assert (
len(old_value) == num_seqs
), f"{key=} {old_value=} {num_seqs=} {batch=}"
Expand All @@ -336,11 +386,11 @@ def filter_batch(
"mrope_positions", # only used by qwen2-vl, thus not care
]:
output_dict[key] = getattr(batch, key)

assert (
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
== batch.extend_num_tokens
), f"{batch=}"
if not batch.forward_mode.is_target_verify():
assert (
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
== batch.extend_num_tokens
), f"{batch=}"
extend_num_tokens = _compute_extend_num_tokens(
output_dict["input_ids"], output_dict["forward_mode"]
)
Expand Down Expand Up @@ -416,18 +466,27 @@ def compute_tbo_children_num_token_non_padded_raw(

@classmethod
def _compute_split_token_index(cls, batch: ForwardBatch):
    """Return the token index at which ``batch`` is split for two-batch overlap.

    Delegates to ``compute_split_token_index`` with the batch's
    ``tbo_split_seq_index``, ``forward_mode`` and ``extend_seq_lens_cpu``.
    ``batch.spec_info.draft_token_num`` is read only in target-verify mode;
    in every other mode ``None`` is passed as ``draft_token_num_per_batch``.
    """
    mode = batch.forward_mode
    tokens_per_seq = (
        batch.spec_info.draft_token_num if mode.is_target_verify() else None
    )
    return compute_split_token_index(
        split_seq_index=batch.tbo_split_seq_index,
        forward_mode=mode,
        extend_seq_lens=batch.extend_seq_lens_cpu,
        draft_token_num_per_batch=tokens_per_seq,
    )


def _compute_extend_num_tokens(input_ids, forward_mode: "ForwardMode"):
    """Return the extend-token count for a batch, or ``None`` when it has none.

    Args:
        input_ids: Object whose ``shape[0]`` is the number of tokens.
        forward_mode: Mode object exposing ``is_decode()``, ``is_idle()``,
            ``is_target_verify()`` and ``is_extend()``.

    Returns:
        ``None`` for decode, idle and target-verify batches;
        ``input_ids.shape[0]`` for extend batches.

    Raises:
        NotImplementedError: If the mode matches none of the above.
    """
    # Target-verify must be tested before is_extend().
    # NOTE(review): presumably is_extend() is also true for target-verify
    # (other call sites guard with `is_extend() and not is_target_verify()`)
    # -- confirm against ForwardMode.
    if (
        forward_mode.is_decode()
        or forward_mode.is_idle()
        or forward_mode.is_target_verify()
    ):
        return None
    elif forward_mode.is_extend():
        return input_ids.shape[0]
    raise NotImplementedError


Expand Down
Loading
Loading