diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index b1e0f1ebb314..891cbacfaaf4 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -259,8 +259,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float |
| `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str |
| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` |
-| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | `None` | |
-| `--speculative-moe-a2a-backend` | MOE A2A backend for EAGLE speculative decoding, see --moe-a2a-backend for options. Same as moe a2a backend if unset. | `None` | |
+| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | `None` | |
## Ngram speculative decoding
| Argument | Description | Defaults | Options |
@@ -277,7 +276,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--expert-parallel-size`
`--ep-size`
`--ep` | The expert parallelism size. | `1` | Type: int |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep`|
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep` |
| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 9740095a5e3f..b409cfdbf4ae 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -594,6 +594,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
seq_lens = seq_lens + self.num_draft_tokens
self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
+ max_seq = forward_batch.seq_lens_cpu.max().item()
+
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
cu_seqlens_q = torch.nn.functional.pad(
@@ -622,7 +624,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
self.forward_decode_metadata.block_kv_indices = block_kv_indices
self.forward_decode_metadata.max_seq_len_k = int(max_seq)
- self.forward_decode_metadata.batch_size = bs
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
else:
@@ -862,14 +863,6 @@ def forward_decode(
or self.forward_decode_metadata
)
- # Ensure batch_size is sufficient, the batch size increase due to the padding from the forward batch
- # FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends
- # and padding logic in prepare_mlp_sync_batch to avoid this
- batch_size = getattr(metadata, "batch_size", None)
- if batch_size is not None and batch_size < forward_batch.batch_size:
- self.init_forward_metadata(forward_batch)
- metadata = forward_batch.decode_trtllm_mla_metadata
-
# Scale computation for TRTLLM MLA kernel BMM1 operation:
# The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
# Scale components:
@@ -983,14 +976,6 @@ def forward_extend(
or self.forward_decode_metadata
)
- # Ensure batch_size is sufficient, the batch size increase due to the padding from the forward batch
- # FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends
- # and padding logic in prepare_mlp_sync_batch to avoid this
- batch_size = getattr(metadata, "batch_size", None)
- if batch_size is not None and batch_size < forward_batch.batch_size:
- self.init_forward_metadata(forward_batch)
- metadata = forward_batch.decode_trtllm_mla_metadata
-
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
bs = forward_batch.batch_size
@@ -1012,6 +997,27 @@ def forward_extend(
)
else:
max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
+ # Check if we're in CUDA graph mode (buffers are pre-allocated)
+ if self.padded_q_buffer is not None:
+ # Use pre-allocated buffer for CUDA graph compatibility
+ padded_q = self.padded_q_buffer[
+ :bs, : metadata.max_seq_len_q, :, :
+ ].to(dtype=q.dtype)
+ else:
+ # Dynamic allocation for non-CUDA graph mode
+ padded_q = torch.zeros(
+ bs,
+ metadata.max_seq_len_q,
+ layer.tp_q_head_num,
+ layer.head_dim,
+ dtype=q.dtype,
+ device=q.device,
+ )
+ q = self.pad_draft_extend_query(
+ q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
+ )
+
+ # TODO may use `mla_rope_quantize_fp8` fusion
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
assert kv_cache.dtype == self.data_type
@@ -1028,6 +1034,15 @@ def forward_extend(
bmm1_scale=bmm1_scale,
)
+ # Reshape output directly without slicing
+
+ if forward_batch.forward_mode.is_draft_extend(include_v2=True):
+ raw_out = self.unpad_draft_extend_output(
+ raw_out,
+ metadata.cu_seqlens_q,
+ metadata.seq_lens_q,
+ metadata.sum_seq_lens_q,
+ )
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index 70466bb20838..85d28b148d73 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -122,7 +122,6 @@ def is_auto(self) -> bool:
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
SPECULATIVE_MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
-SPECULATIVE_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
IS_SBO_ENABLED: Optional[bool] = None
@@ -136,7 +135,6 @@ def initialize_moe_config(server_args: ServerArgs):
global MOE_A2A_BACKEND
global MOE_RUNNER_BACKEND
global SPECULATIVE_MOE_RUNNER_BACKEND
- global SPECULATIVE_MOE_A2A_BACKEND
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
@@ -152,11 +150,6 @@ def initialize_moe_config(server_args: ServerArgs):
if server_args.speculative_moe_runner_backend is not None
else MOE_RUNNER_BACKEND
)
- SPECULATIVE_MOE_A2A_BACKEND = (
- MoeA2ABackend(server_args.speculative_moe_a2a_backend)
- if server_args.speculative_moe_a2a_backend is not None
- else MOE_A2A_BACKEND
- )
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
@@ -196,16 +189,6 @@ def get_speculative_moe_runner_backend() -> MoeRunnerBackend:
return SPECULATIVE_MOE_RUNNER_BACKEND
-def get_speculative_moe_a2a_backend() -> MoeA2ABackend:
- global SPECULATIVE_MOE_A2A_BACKEND
- if SPECULATIVE_MOE_A2A_BACKEND is None:
- logger.warning(
- "SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend"
- )
- SPECULATIVE_MOE_A2A_BACKEND = MoeA2ABackend.NONE
- return SPECULATIVE_MOE_A2A_BACKEND
-
-
def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
@@ -275,21 +258,6 @@ def speculative_moe_backend_context():
MOE_RUNNER_BACKEND = original_backend
-@contextmanager
-def speculative_moe_a2a_backend_context():
- """
- Context manager to temporarily use the speculative MoE A2A backend for draft model operations.
- This ensures that draft models in speculative decoding use the configured speculative A2A backend.
- """
- global MOE_A2A_BACKEND
- original_backend = MOE_A2A_BACKEND
- try:
- MOE_A2A_BACKEND = MoeA2ABackend.NONE
- yield
- finally:
- MOE_A2A_BACKEND = original_backend
-
-
# The type of method in top-K routing, for use in torch custom op
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index d4ed92d0ec2d..4de8a4ef7df8 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -772,12 +772,7 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
bs = self.batch_size
- if (
- self.forward_mode.is_decode()
- or self.forward_mode.is_target_verify()
- or self.forward_mode.is_draft_extend(include_v2=True)
- or self.forward_mode.is_idle()
- ):
+ if self.forward_mode.is_decode():
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
setattr(self, "_original_forward_mode", self.forward_mode)
self.forward_mode = ForwardMode.EXTEND
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index ea8ec92a51d9..9f866499d987 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2520,9 +2520,7 @@ def forward_normal_chunked_kv_prepare(
)
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
- has_extend_prefix = forward_batch.extend_prefix_lens_cpu is not None and any(
- forward_batch.extend_prefix_lens_cpu
- )
+ has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
# Only initialize the info once
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 68bd6709cb52..f95053206eb8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -168,8 +168,6 @@
"cutlass",
]
-MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep"]
-
MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]
@@ -397,7 +395,6 @@ class ServerArgs:
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
speculative_moe_runner_backend: Optional[str] = None
- speculative_moe_a2a_backend: Optional[str] = None
# Speculative decoding (ngram)
speculative_ngram_min_match_window_size: int = 1
@@ -3039,13 +3036,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.speculative_moe_runner_backend,
help="Choose the runner backend for MoE in speculative decoding.",
)
- parser.add_argument(
- "--speculative-moe-a2a-backend",
- type=str,
- choices=MOE_A2A_BACKEND_CHOICES,
- default=ServerArgs.speculative_moe_a2a_backend,
- help="Choose the backend for MoE A2A in speculative decoding",
- )
# Speculative decoding (ngram)
parser.add_argument(
@@ -3104,7 +3094,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--moe-a2a-backend",
type=str,
- choices=MOE_A2A_BACKEND_CHOICES,
+ choices=["none", "deepep", "mooncake", "ascend_fuseep"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index b3d72df05cfe..51e7bf060856 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -67,9 +67,6 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None
- # Shape info for padding
- num_tokens_per_batch: int = -1
-
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_VERIFY)
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 07e3798f1f09..0add45539222 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -10,10 +10,7 @@
)
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import (
- speculative_moe_a2a_backend_context,
- speculative_moe_backend_context,
-)
+from sglang.srt.layers.moe.utils import speculative_moe_backend_context
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -135,9 +132,7 @@ def __init__(
ctx = draft_tp_context(get_attention_tp_group())
else:
ctx = empty_context()
- with (
- ctx
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ with ctx, speculative_moe_backend_context():
super().__init__(
server_args=server_args,
gpu_id=gpu_id,
@@ -188,7 +183,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ ), speculative_moe_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()
@@ -281,7 +276,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ ), speculative_moe_backend_context():
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
@@ -294,7 +289,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
else:
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ ), speculative_moe_backend_context():
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
@@ -302,7 +297,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ ), speculative_moe_backend_context():
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
if (
@@ -670,7 +665,6 @@ def clear_cache_pool(self):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
- spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
batch.return_hidden_states = False
batch.forward_mode = (
ForwardMode.TARGET_VERIFY
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 629ddf27ef02..579d8f57c89d 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -12,10 +12,7 @@
from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import (
EAGLEDraftNpuGraphRunner,
)
-from sglang.srt.layers.moe.utils import (
- speculative_moe_a2a_backend_context,
- speculative_moe_backend_context,
-)
+from sglang.srt.layers.moe.utils import speculative_moe_backend_context
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
@@ -115,7 +112,7 @@ def __init__(
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
- with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ with empty_context(), speculative_moe_backend_context():
# Init draft worker
self.draft_worker = TpModelWorker(
server_args=server_args,
@@ -143,7 +140,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_runner.tp_group
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ ), speculative_moe_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()
@@ -614,15 +611,12 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
- with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
- batch_output.next_draft_input = (
- self.draft_worker._draft_extend_for_prefill(
- model_worker_batch,
- batch_output.logits_output.hidden_states,
- batch_output.next_token_ids,
- )
- )
- return batch_output
+ batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
+ model_worker_batch,
+ batch_output.logits_output.hidden_states,
+ batch_output.next_token_ids,
+ )
+ return batch_output
else:
if model_worker_batch.spec_info is None:
model_worker_batch.spec_info = EagleDraftInput.create_idle_input(
@@ -632,17 +626,11 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
- with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
- verify_input: EagleVerifyInput = self.draft_worker.draft(
- model_worker_batch
- )
+ verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch)
- with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
- self.draft_worker._draft_extend_for_decode(
- model_worker_batch, batch_output
- )
+ self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
return batch_output
def verify(self, batch: ModelWorkerBatch):
@@ -655,7 +643,6 @@ def verify(self, batch: ModelWorkerBatch):
# Parse args
verify_input: EagleVerifyInput = batch.spec_info
- verify_input.num_tokens_per_batch = self.speculative_num_steps + 1
bs = len(batch.seq_lens)
# Batch 1: Target verify
diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py
index e1f331975bde..230a6ed00cd5 100644
--- a/python/sglang/srt/speculative/standalone_worker.py
+++ b/python/sglang/srt/speculative/standalone_worker.py
@@ -3,10 +3,7 @@
import torch
-from sglang.srt.layers.moe.utils import (
- speculative_moe_a2a_backend_context,
- speculative_moe_backend_context,
-)
+from sglang.srt.layers.moe.utils import speculative_moe_backend_context
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -70,7 +67,7 @@ def __init__(
self.hot_token_id = None
# Init draft worker
- with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ with empty_context(), speculative_moe_backend_context():
TpModelWorker.__init__(
self,
server_args=server_args,
@@ -94,7 +91,7 @@ def __init__(
)
with self.draft_tp_context(
self.draft_model_runner.tp_group
- ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+ ), speculative_moe_backend_context():
self.init_attention_backend()
self.init_cuda_graphs()
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index cf7264225dc6..78d40064f3c1 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -2799,8 +2799,6 @@ def require_mlp_tp_gather(server_args: ServerArgs):
"""
Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
"""
- from sglang.srt.layers.moe.utils import get_moe_a2a_backend
-
if server_args.enable_dp_attention:
assert server_args.dp_size > 1, "dp_size must be greater than 1"
if (
@@ -2809,7 +2807,7 @@ def require_mlp_tp_gather(server_args: ServerArgs):
return True
elif not server_args.enable_dp_lm_head:
return True
- elif get_moe_a2a_backend().is_none():
+ elif server_args.moe_a2a_backend == "none":
return True
else:
return (
@@ -2824,10 +2822,8 @@ def require_attn_tp_gather(server_args: ServerArgs):
"""
Check if the input of attention is scattered.
"""
- from sglang.srt.layers.moe.utils import get_moe_a2a_backend
-
assert server_args.moe_dense_tp_size in [1, None]
- if not get_moe_a2a_backend().is_none() or server_args.moe_dense_tp_size == 1:
+ if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else: