diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 22f92109e1ff..3bab890f26ab 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -259,7 +259,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float |
| `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str |
| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` |
-| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | None |
+| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | `None` | |
+| `--speculative-moe-a2a-backend` | MOE A2A backend for EAGLE speculative decoding, see --moe-a2a-backend for options. Same as moe a2a backend if unset. | `None` | |
## Ngram speculative decoding
| Argument | Description | Defaults | Options |
@@ -276,7 +277,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--expert-parallel-size`
`--ep-size`
`--ep` | The expert parallelism size. | `1` | Type: int |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep` |
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep` |
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index b409cfdbf4ae..9740095a5e3f 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -594,8 +594,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
seq_lens = seq_lens + self.num_draft_tokens
self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
- max_seq = forward_batch.seq_lens_cpu.max().item()
-
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
cu_seqlens_q = torch.nn.functional.pad(
@@ -624,6 +622,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
self.forward_decode_metadata.block_kv_indices = block_kv_indices
self.forward_decode_metadata.max_seq_len_k = int(max_seq)
+ self.forward_decode_metadata.batch_size = bs
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
else:
@@ -863,6 +862,14 @@ def forward_decode(
or self.forward_decode_metadata
)
+    # Ensure the cached batch_size is sufficient; it can increase due to padding from the forward batch.
+    # FIXME(@rainj-me), refactor skip_attn_backend_init, init_forward_metadata for attn backends
+    # and the padding logic in prepare_mlp_sync_batch to avoid this re-initialization
+ batch_size = getattr(metadata, "batch_size", None)
+ if batch_size is not None and batch_size < forward_batch.batch_size:
+ self.init_forward_metadata(forward_batch)
+ metadata = forward_batch.decode_trtllm_mla_metadata
+
# Scale computation for TRTLLM MLA kernel BMM1 operation:
# The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
# Scale components:
@@ -976,6 +983,14 @@ def forward_extend(
or self.forward_decode_metadata
)
+    # Ensure the cached batch_size is sufficient; it can increase due to padding from the forward batch.
+    # FIXME(@rainj-me), refactor skip_attn_backend_init, init_forward_metadata for attn backends
+    # and the padding logic in prepare_mlp_sync_batch to avoid this re-initialization
+ batch_size = getattr(metadata, "batch_size", None)
+ if batch_size is not None and batch_size < forward_batch.batch_size:
+ self.init_forward_metadata(forward_batch)
+ metadata = forward_batch.decode_trtllm_mla_metadata
+
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
bs = forward_batch.batch_size
@@ -997,27 +1012,6 @@ def forward_extend(
)
else:
max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
- # Check if we're in CUDA graph mode (buffers are pre-allocated)
- if self.padded_q_buffer is not None:
- # Use pre-allocated buffer for CUDA graph compatibility
- padded_q = self.padded_q_buffer[
- :bs, : metadata.max_seq_len_q, :, :
- ].to(dtype=q.dtype)
- else:
- # Dynamic allocation for non-CUDA graph mode
- padded_q = torch.zeros(
- bs,
- metadata.max_seq_len_q,
- layer.tp_q_head_num,
- layer.head_dim,
- dtype=q.dtype,
- device=q.device,
- )
- q = self.pad_draft_extend_query(
- q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
- )
-
- # TODO may use `mla_rope_quantize_fp8` fusion
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
assert kv_cache.dtype == self.data_type
@@ -1034,15 +1028,6 @@ def forward_extend(
bmm1_scale=bmm1_scale,
)
- # Reshape output directly without slicing
-
- if forward_batch.forward_mode.is_draft_extend(include_v2=True):
- raw_out = self.unpad_draft_extend_output(
- raw_out,
- metadata.cu_seqlens_q,
- metadata.seq_lens_q,
- metadata.sum_seq_lens_q,
- )
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index 85d28b148d73..70466bb20838 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -122,6 +122,7 @@ def is_auto(self) -> bool:
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
SPECULATIVE_MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
+SPECULATIVE_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
IS_SBO_ENABLED: Optional[bool] = None
@@ -135,6 +136,7 @@ def initialize_moe_config(server_args: ServerArgs):
global MOE_A2A_BACKEND
global MOE_RUNNER_BACKEND
global SPECULATIVE_MOE_RUNNER_BACKEND
+ global SPECULATIVE_MOE_A2A_BACKEND
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
@@ -150,6 +152,11 @@ def initialize_moe_config(server_args: ServerArgs):
if server_args.speculative_moe_runner_backend is not None
else MOE_RUNNER_BACKEND
)
+ SPECULATIVE_MOE_A2A_BACKEND = (
+ MoeA2ABackend(server_args.speculative_moe_a2a_backend)
+ if server_args.speculative_moe_a2a_backend is not None
+ else MOE_A2A_BACKEND
+ )
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
@@ -189,6 +196,16 @@ def get_speculative_moe_runner_backend() -> MoeRunnerBackend:
return SPECULATIVE_MOE_RUNNER_BACKEND
+def get_speculative_moe_a2a_backend() -> MoeA2ABackend:
+ global SPECULATIVE_MOE_A2A_BACKEND
+ if SPECULATIVE_MOE_A2A_BACKEND is None:
+ logger.warning(
+ "SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend"
+ )
+ SPECULATIVE_MOE_A2A_BACKEND = MoeA2ABackend.NONE
+ return SPECULATIVE_MOE_A2A_BACKEND
+
+
def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
@@ -258,6 +275,21 @@ def speculative_moe_backend_context():
MOE_RUNNER_BACKEND = original_backend
+@contextmanager
+def speculative_moe_a2a_backend_context():
+    """
+    Temporarily swap MOE_A2A_BACKEND to the configured speculative A2A backend for draft model ops.
+    Falls back to the target model's A2A backend when --speculative-moe-a2a-backend is unset.
+    """
+    global MOE_A2A_BACKEND
+    original_backend = MOE_A2A_BACKEND
+    try:
+        MOE_A2A_BACKEND = get_speculative_moe_a2a_backend()
+        yield
+    finally:
+        MOE_A2A_BACKEND = original_backend
+
+
# The type of method in top-K routing, for use in torch custom op
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 8baddc56f0c5..3437f2ef869d 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -764,7 +764,12 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
bs = self.batch_size
- if self.forward_mode.is_decode():
+ if (
+ self.forward_mode.is_decode()
+ or self.forward_mode.is_target_verify()
+ or self.forward_mode.is_draft_extend(include_v2=True)
+ or self.forward_mode.is_idle()
+ ):
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
setattr(self, "_original_forward_mode", self.forward_mode)
self.forward_mode = ForwardMode.EXTEND
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 737dd5cf10ba..a6a38507b3e5 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2502,7 +2502,9 @@ def forward_normal_chunked_kv_prepare(
)
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
- has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+ has_extend_prefix = forward_batch.extend_prefix_lens_cpu is not None and any(
+ forward_batch.extend_prefix_lens_cpu
+ )
# Only initialize the info once
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 69440d16a097..95577981cf5e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -167,6 +167,8 @@
"cutlass",
]
+MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep"]
+
MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]
@@ -394,6 +396,7 @@ class ServerArgs:
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
speculative_moe_runner_backend: Optional[str] = None
+ speculative_moe_a2a_backend: Optional[str] = None
# Speculative decoding (ngram)
speculative_ngram_min_match_window_size: int = 1
@@ -3007,6 +3010,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.speculative_moe_runner_backend,
help="Choose the runner backend for MoE in speculative decoding.",
)
+ parser.add_argument(
+ "--speculative-moe-a2a-backend",
+ type=str,
+ choices=MOE_A2A_BACKEND_CHOICES,
+ default=ServerArgs.speculative_moe_a2a_backend,
+ help="Choose the backend for MoE A2A in speculative decoding",
+ )
# Speculative decoding (ngram)
parser.add_argument(
@@ -3065,7 +3075,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--moe-a2a-backend",
type=str,
- choices=["none", "deepep", "mooncake", "ascend_fuseep"],
+ choices=MOE_A2A_BACKEND_CHOICES,
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index 51e7bf060856..b3d72df05cfe 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -67,6 +67,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None
+ # Shape info for padding
+ num_tokens_per_batch: int = -1
+
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_VERIFY)
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 0add45539222..07e3798f1f09 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -10,7 +10,10 @@
)
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import speculative_moe_backend_context
+from sglang.srt.layers.moe.utils import (
+ speculative_moe_a2a_backend_context,
+ speculative_moe_backend_context,
+)
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -132,7 +135,9 @@ def __init__(
ctx = draft_tp_context(get_attention_tp_group())
else:
ctx = empty_context()
- with ctx, speculative_moe_backend_context():
+ with (
+ ctx
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
super().__init__(
server_args=server_args,
gpu_id=gpu_id,
@@ -183,7 +188,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context():
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()
@@ -276,7 +281,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context():
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
@@ -289,7 +294,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
else:
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context():
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
@@ -297,7 +302,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context():
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
if (
@@ -665,6 +670,7 @@ def clear_cache_pool(self):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
+ spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
batch.return_hidden_states = False
batch.forward_mode = (
ForwardMode.TARGET_VERIFY
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 579d8f57c89d..629ddf27ef02 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -12,7 +12,10 @@
from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import (
EAGLEDraftNpuGraphRunner,
)
-from sglang.srt.layers.moe.utils import speculative_moe_backend_context
+from sglang.srt.layers.moe.utils import (
+ speculative_moe_a2a_backend_context,
+ speculative_moe_backend_context,
+)
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
@@ -112,7 +115,7 @@ def __init__(
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
- with empty_context(), speculative_moe_backend_context():
+ with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
# Init draft worker
self.draft_worker = TpModelWorker(
server_args=server_args,
@@ -140,7 +143,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_runner.tp_group
- ), speculative_moe_backend_context():
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()
@@ -611,12 +614,15 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
- batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
- model_worker_batch,
- batch_output.logits_output.hidden_states,
- batch_output.next_token_ids,
- )
- return batch_output
+ with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ batch_output.next_draft_input = (
+ self.draft_worker._draft_extend_for_prefill(
+ model_worker_batch,
+ batch_output.logits_output.hidden_states,
+ batch_output.next_token_ids,
+ )
+ )
+ return batch_output
else:
if model_worker_batch.spec_info is None:
model_worker_batch.spec_info = EagleDraftInput.create_idle_input(
@@ -626,11 +632,17 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
- verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
+ with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ verify_input: EagleVerifyInput = self.draft_worker.draft(
+ model_worker_batch
+ )
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch)
- self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
+ with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ self.draft_worker._draft_extend_for_decode(
+ model_worker_batch, batch_output
+ )
return batch_output
def verify(self, batch: ModelWorkerBatch):
@@ -643,6 +655,7 @@ def verify(self, batch: ModelWorkerBatch):
# Parse args
verify_input: EagleVerifyInput = batch.spec_info
+ verify_input.num_tokens_per_batch = self.speculative_num_steps + 1
bs = len(batch.seq_lens)
# Batch 1: Target verify
diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py
index 230a6ed00cd5..e1f331975bde 100644
--- a/python/sglang/srt/speculative/standalone_worker.py
+++ b/python/sglang/srt/speculative/standalone_worker.py
@@ -3,7 +3,10 @@
import torch
-from sglang.srt.layers.moe.utils import speculative_moe_backend_context
+from sglang.srt.layers.moe.utils import (
+ speculative_moe_a2a_backend_context,
+ speculative_moe_backend_context,
+)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -67,7 +70,7 @@ def __init__(
self.hot_token_id = None
# Init draft worker
- with empty_context(), speculative_moe_backend_context():
+ with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
TpModelWorker.__init__(
self,
server_args=server_args,
@@ -91,7 +94,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context():
+ ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 08a13f604f81..6fa0b2404ba0 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -2797,6 +2797,8 @@ def require_mlp_tp_gather(server_args: ServerArgs):
"""
Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
"""
+ from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
if server_args.enable_dp_attention:
assert server_args.dp_size > 1, "dp_size must be greater than 1"
if (
@@ -2805,7 +2807,7 @@ def require_mlp_tp_gather(server_args: ServerArgs):
return True
elif not server_args.enable_dp_lm_head:
return True
- elif server_args.moe_a2a_backend == "none":
+ elif get_moe_a2a_backend().is_none():
return True
else:
return (
@@ -2820,8 +2822,10 @@ def require_attn_tp_gather(server_args: ServerArgs):
"""
Check if the input of attention is scattered.
"""
+ from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
assert server_args.moe_dense_tp_size in [1, None]
- if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
+ if not get_moe_a2a_backend().is_none() or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else: