Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float |
| `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str |
| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` |
| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | None |
| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | `None` | |
| `--speculative-moe-a2a-backend` | MOE A2A backend for EAGLE speculative decoding, see --moe-a2a-backend for options. Same as moe a2a backend if unset. | `None` | |

## Ngram speculative decoding
| Argument | Description | Defaults | Options |
Expand All @@ -276,7 +277,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep` |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep` |
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
Expand Down
49 changes: 17 additions & 32 deletions python/sglang/srt/layers/attention/trtllm_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
seq_lens = seq_lens + self.num_draft_tokens
self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
max_seq = forward_batch.seq_lens_cpu.max().item()

sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
cu_seqlens_q = torch.nn.functional.pad(
Expand Down Expand Up @@ -624,6 +622,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

self.forward_decode_metadata.block_kv_indices = block_kv_indices
self.forward_decode_metadata.max_seq_len_k = int(max_seq)
self.forward_decode_metadata.batch_size = bs

forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
else:
Expand Down Expand Up @@ -863,6 +862,14 @@ def forward_decode(
or self.forward_decode_metadata
)

        # Ensure batch_size is sufficient; the batch size may increase due to padding in the forward batch.
# FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends
# and padding logic in prepare_mlp_sync_batch to avoid this
batch_size = getattr(metadata, "batch_size", None)
if batch_size is not None and batch_size < forward_batch.batch_size:
self.init_forward_metadata(forward_batch)
metadata = forward_batch.decode_trtllm_mla_metadata

# Scale computation for TRTLLM MLA kernel BMM1 operation:
# The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
# Scale components:
Expand Down Expand Up @@ -976,6 +983,14 @@ def forward_extend(
or self.forward_decode_metadata
)

        # Ensure batch_size is sufficient; the batch size may increase due to padding in the forward batch.
# FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends
# and padding logic in prepare_mlp_sync_batch to avoid this
batch_size = getattr(metadata, "batch_size", None)
if batch_size is not None and batch_size < forward_batch.batch_size:
self.init_forward_metadata(forward_batch)
metadata = forward_batch.decode_trtllm_mla_metadata

# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
bs = forward_batch.batch_size

Expand All @@ -997,27 +1012,6 @@ def forward_extend(
)
else:
max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
# Check if we're in CUDA graph mode (buffers are pre-allocated)
if self.padded_q_buffer is not None:
# Use pre-allocated buffer for CUDA graph compatibility
padded_q = self.padded_q_buffer[
:bs, : metadata.max_seq_len_q, :, :
].to(dtype=q.dtype)
else:
# Dynamic allocation for non-CUDA graph mode
padded_q = torch.zeros(
bs,
metadata.max_seq_len_q,
layer.tp_q_head_num,
layer.head_dim,
dtype=q.dtype,
device=q.device,
)
q = self.pad_draft_extend_query(
q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
)

# TODO may use `mla_rope_quantize_fp8` fusion
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
assert kv_cache.dtype == self.data_type

Expand All @@ -1034,15 +1028,6 @@ def forward_extend(
bmm1_scale=bmm1_scale,
)

# Reshape output directly without slicing

if forward_batch.forward_mode.is_draft_extend(include_v2=True):
raw_out = self.unpad_draft_extend_output(
raw_out,
metadata.cu_seqlens_q,
metadata.seq_lens_q,
metadata.sum_seq_lens_q,
)
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output

Expand Down
32 changes: 32 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def is_auto(self) -> bool:
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
SPECULATIVE_MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
SPECULATIVE_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
IS_SBO_ENABLED: Optional[bool] = None
Expand All @@ -135,6 +136,7 @@ def initialize_moe_config(server_args: ServerArgs):
global MOE_A2A_BACKEND
global MOE_RUNNER_BACKEND
global SPECULATIVE_MOE_RUNNER_BACKEND
global SPECULATIVE_MOE_A2A_BACKEND
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
Expand All @@ -150,6 +152,11 @@ def initialize_moe_config(server_args: ServerArgs):
if server_args.speculative_moe_runner_backend is not None
else MOE_RUNNER_BACKEND
)
SPECULATIVE_MOE_A2A_BACKEND = (
MoeA2ABackend(server_args.speculative_moe_a2a_backend)
if server_args.speculative_moe_a2a_backend is not None
else MOE_A2A_BACKEND
)
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
Expand Down Expand Up @@ -189,6 +196,16 @@ def get_speculative_moe_runner_backend() -> MoeRunnerBackend:
return SPECULATIVE_MOE_RUNNER_BACKEND


def get_speculative_moe_a2a_backend() -> MoeA2ABackend:
    """Return the MoE A2A backend to use for speculative (draft) decoding.

    The global is normally populated by ``initialize_moe_config``; when it is
    still unset, fall back to ``MoeA2ABackend.NONE`` and warn once so the
    caller can proceed with a safe default.
    """
    global SPECULATIVE_MOE_A2A_BACKEND
    if SPECULATIVE_MOE_A2A_BACKEND is not None:
        return SPECULATIVE_MOE_A2A_BACKEND
    logger.warning(
        "SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend"
    )
    SPECULATIVE_MOE_A2A_BACKEND = MoeA2ABackend.NONE
    return SPECULATIVE_MOE_A2A_BACKEND


def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
Expand Down Expand Up @@ -258,6 +275,21 @@ def speculative_moe_backend_context():
MOE_RUNNER_BACKEND = original_backend


@contextmanager
def speculative_moe_a2a_backend_context():
    """Temporarily switch the global MoE A2A backend for draft-model work.

    While active, ``MOE_A2A_BACKEND`` is replaced by the speculative A2A
    backend (configured via ``--speculative-moe-a2a-backend``; per
    ``initialize_moe_config`` it defaults to the regular A2A backend when
    unset), so draft models in speculative decoding pick up the intended
    backend. The previous backend is always restored on exit, even if the
    wrapped code raises.
    """
    global MOE_A2A_BACKEND
    original_backend = MOE_A2A_BACKEND
    try:
        # Fix: previously hard-coded MoeA2ABackend.NONE, which silently
        # ignored the configured speculative A2A backend and left
        # get_speculative_moe_a2a_backend() unused, contradicting this
        # context manager's documented purpose.
        MOE_A2A_BACKEND = get_speculative_moe_a2a_backend()
        yield
    finally:
        MOE_A2A_BACKEND = original_backend


# The type of method in top-K routing, for use in torch custom op
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,12 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):

bs = self.batch_size

if self.forward_mode.is_decode():
if (
self.forward_mode.is_decode()
or self.forward_mode.is_target_verify()
or self.forward_mode.is_draft_extend(include_v2=True)
or self.forward_mode.is_idle()
):
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
setattr(self, "_original_forward_mode", self.forward_mode)
self.forward_mode = ForwardMode.EXTEND
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,7 +2502,9 @@ def forward_normal_chunked_kv_prepare(
)

def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
has_extend_prefix = forward_batch.extend_prefix_lens_cpu is not None and any(
forward_batch.extend_prefix_lens_cpu
)
# Only initialize the info once
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
Expand Down
12 changes: 11 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@
"cutlass",
]

MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep"]

MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]


Expand Down Expand Up @@ -394,6 +396,7 @@ class ServerArgs:
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
speculative_moe_runner_backend: Optional[str] = None
speculative_moe_a2a_backend: Optional[str] = None

# Speculative decoding (ngram)
speculative_ngram_min_match_window_size: int = 1
Expand Down Expand Up @@ -3007,6 +3010,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.speculative_moe_runner_backend,
help="Choose the runner backend for MoE in speculative decoding.",
)
parser.add_argument(
"--speculative-moe-a2a-backend",
type=str,
choices=MOE_A2A_BACKEND_CHOICES,
default=ServerArgs.speculative_moe_a2a_backend,
help="Choose the backend for MoE A2A in speculative decoding",
)

# Speculative decoding (ngram)
parser.add_argument(
Expand Down Expand Up @@ -3065,7 +3075,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--moe-a2a-backend",
type=str,
choices=["none", "deepep", "mooncake", "ascend_fuseep"],
choices=MOE_A2A_BACKEND_CHOICES,
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/speculative/eagle_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None

# Shape info for padding
num_tokens_per_batch: int = -1

def __post_init__(self):
super().__init__(SpecInputType.EAGLE_VERIFY)

Expand Down
18 changes: 12 additions & 6 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
)
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import speculative_moe_backend_context
from sglang.srt.layers.moe.utils import (
speculative_moe_a2a_backend_context,
speculative_moe_backend_context,
)
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.schedule_batch import ScheduleBatch
Expand Down Expand Up @@ -132,7 +135,9 @@ def __init__(
ctx = draft_tp_context(get_attention_tp_group())
else:
ctx = empty_context()
with ctx, speculative_moe_backend_context():
with (
ctx
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
super().__init__(
server_args=server_args,
gpu_id=gpu_id,
Expand Down Expand Up @@ -183,7 +188,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context():
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()

Expand Down Expand Up @@ -276,7 +281,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context():
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
Expand All @@ -289,15 +294,15 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
else:
with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context():
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)

with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context():
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
if (
Expand Down Expand Up @@ -665,6 +670,7 @@ def clear_cache_pool(self):

def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
batch.return_hidden_states = False
batch.forward_mode = (
ForwardMode.TARGET_VERIFY
Expand Down
35 changes: 24 additions & 11 deletions python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import (
EAGLEDraftNpuGraphRunner,
)
from sglang.srt.layers.moe.utils import speculative_moe_backend_context
from sglang.srt.layers.moe.utils import (
speculative_moe_a2a_backend_context,
speculative_moe_backend_context,
)
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
Expand Down Expand Up @@ -112,7 +115,7 @@ def __init__(
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
with empty_context(), speculative_moe_backend_context():
with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
# Init draft worker
self.draft_worker = TpModelWorker(
server_args=server_args,
Expand Down Expand Up @@ -140,7 +143,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_runner.tp_group
), speculative_moe_backend_context():
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()

Expand Down Expand Up @@ -611,12 +614,15 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):

# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
model_worker_batch,
batch_output.logits_output.hidden_states,
batch_output.next_token_ids,
)
return batch_output
with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
batch_output.next_draft_input = (
self.draft_worker._draft_extend_for_prefill(
model_worker_batch,
batch_output.logits_output.hidden_states,
batch_output.next_token_ids,
)
)
return batch_output
else:
if model_worker_batch.spec_info is None:
model_worker_batch.spec_info = EagleDraftInput.create_idle_input(
Expand All @@ -626,11 +632,17 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
verify_input: EagleVerifyInput = self.draft_worker.draft(
model_worker_batch
)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch)
self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.draft_worker._draft_extend_for_decode(
model_worker_batch, batch_output
)
return batch_output

def verify(self, batch: ModelWorkerBatch):
Expand All @@ -643,6 +655,7 @@ def verify(self, batch: ModelWorkerBatch):

# Parse args
verify_input: EagleVerifyInput = batch.spec_info
verify_input.num_tokens_per_batch = self.speculative_num_steps + 1
bs = len(batch.seq_lens)

# Batch 1: Target verify
Expand Down
Loading
Loading