diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 22f92109e1ff..3bab890f26ab 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -259,7 +259,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float |
 | `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str |
 | `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` |
-| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | None |
+| `--speculative-moe-runner-backend` | MoE runner backend for EAGLE speculative decoding; see `--moe-runner-backend` for options. Same as `--moe-runner-backend` if unset. | `None` | Type: str |
+| `--speculative-moe-a2a-backend` | MoE A2A backend for EAGLE speculative decoding; see `--moe-a2a-backend` for options. Same as `--moe-a2a-backend` if unset. | `None` | Type: str |
 
 ## Ngram speculative decoding
 
 | Argument | Description | Defaults | Options |
@@ -276,7 +277,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Argument | Description | Defaults | Options |
 | --- | --- | --- | --- |
 | `--expert-parallel-size`<br>`--ep-size`<br>`--ep` | The expert parallelism size. | `1` | Type: int |
-| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep` |
+| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep` |
 | `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl` |
 | `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
 | `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index b409cfdbf4ae..9740095a5e3f 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -594,8 +594,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             seq_lens = seq_lens + self.num_draft_tokens
             self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
         elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
-            max_seq = forward_batch.seq_lens_cpu.max().item()
-            sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
             max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
 
             cu_seqlens_q = torch.nn.functional.pad(
@@ -624,6 +622,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
 
             self.forward_decode_metadata.block_kv_indices = block_kv_indices
             self.forward_decode_metadata.max_seq_len_k = int(max_seq)
+            self.forward_decode_metadata.batch_size = bs
 
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
@@ -863,6 +862,14 @@ def forward_decode(
             or self.forward_decode_metadata
         )
 
+        # Ensure batch_size is sufficient; the batch size may increase due to padding from the forward batch.
+        # FIXME(@rainj-me): refactor skip_attn_backend_init, init_forward_metadata for attn backends,
+        # and the padding logic in prepare_mlp_sync_batch to avoid this.
+        batch_size = getattr(metadata, "batch_size", None)
+        if batch_size is not None and batch_size < forward_batch.batch_size:
+            self.init_forward_metadata(forward_batch)
+            metadata = forward_batch.decode_trtllm_mla_metadata
+
         # Scale computation for TRTLLM MLA kernel BMM1 operation:
         # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
         # Scale components:
@@ -976,6 +983,14 @@ def forward_extend(
             or self.forward_decode_metadata
         )
 
+        # Ensure batch_size is sufficient; the batch size may increase due to padding from the forward batch.
+        # FIXME(@rainj-me): refactor skip_attn_backend_init, init_forward_metadata for attn backends,
+        # and the padding logic in prepare_mlp_sync_batch to avoid this.
+        batch_size = getattr(metadata, "batch_size", None)
+        if batch_size is not None and batch_size < forward_batch.batch_size:
+            self.init_forward_metadata(forward_batch)
+            metadata = forward_batch.decode_trtllm_mla_metadata
+
         # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
         bs = forward_batch.batch_size
 
@@ -997,27 +1012,6 @@ def forward_extend(
             )
         else:
             max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
-            # Check if we're in CUDA graph mode (buffers are pre-allocated)
-            if self.padded_q_buffer is not None:
-                # Use pre-allocated buffer for CUDA graph compatibility
-                padded_q = self.padded_q_buffer[
-                    :bs, : metadata.max_seq_len_q, :, :
-                ].to(dtype=q.dtype)
-            else:
-                # Dynamic allocation for non-CUDA graph mode
-                padded_q = torch.zeros(
-                    bs,
-                    metadata.max_seq_len_q,
-                    layer.tp_q_head_num,
-                    layer.head_dim,
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-            q = self.pad_draft_extend_query(
-                q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
-            )
-
-        # TODO may use `mla_rope_quantize_fp8` fusion
         q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
 
         assert kv_cache.dtype == self.data_type
@@ -1034,15 +1028,6 @@ def forward_extend(
             bmm1_scale=bmm1_scale,
         )
 
-        # Reshape output directly without slicing
-
-        if forward_batch.forward_mode.is_draft_extend(include_v2=True):
-            raw_out = self.unpad_draft_extend_output(
-                raw_out,
-                metadata.cu_seqlens_q,
-                metadata.seq_lens_q,
-                metadata.sum_seq_lens_q,
-            )
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index 85d28b148d73..70466bb20838 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -122,6 +122,7 @@ def is_auto(self) -> bool:
 MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
 MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
 SPECULATIVE_MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
+SPECULATIVE_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
 DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
 IS_SBO_ENABLED: Optional[bool] = None
@@ -135,6 +136,7 @@ def initialize_moe_config(server_args: ServerArgs):
     global MOE_A2A_BACKEND
     global MOE_RUNNER_BACKEND
     global SPECULATIVE_MOE_RUNNER_BACKEND
+    global SPECULATIVE_MOE_A2A_BACKEND
     global DEEPEP_MODE
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
@@ -150,6 +152,11 @@ def initialize_moe_config(server_args: ServerArgs):
         if server_args.speculative_moe_runner_backend is not None
         else MOE_RUNNER_BACKEND
     )
+    SPECULATIVE_MOE_A2A_BACKEND = (
+        MoeA2ABackend(server_args.speculative_moe_a2a_backend)
+        if server_args.speculative_moe_a2a_backend is not None
+        else MOE_A2A_BACKEND
+    )
     DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
@@ -189,6 +196,16 @@ def get_speculative_moe_runner_backend() -> MoeRunnerBackend:
     return SPECULATIVE_MOE_RUNNER_BACKEND
 
 
+def get_speculative_moe_a2a_backend() -> MoeA2ABackend:
+    global SPECULATIVE_MOE_A2A_BACKEND
+    if SPECULATIVE_MOE_A2A_BACKEND is None:
+        logger.warning(
+            "SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend"
+        )
+        SPECULATIVE_MOE_A2A_BACKEND = MoeA2ABackend.NONE
+    return SPECULATIVE_MOE_A2A_BACKEND
+
+
 def get_deepep_mode() -> DeepEPMode:
     global DEEPEP_MODE
     if DEEPEP_MODE is None:
@@ -258,6 +275,21 @@ def speculative_moe_backend_context():
         MOE_RUNNER_BACKEND = original_backend
 
 
+@contextmanager
+def speculative_moe_a2a_backend_context():
+    """
+    Context manager to temporarily use the speculative MoE A2A backend for draft model operations.
+    This ensures that draft models in speculative decoding use the configured speculative A2A backend.
+    """
+    global MOE_A2A_BACKEND
+    original_backend = MOE_A2A_BACKEND
+    try:
+        MOE_A2A_BACKEND = get_speculative_moe_a2a_backend()
+        yield
+    finally:
+        MOE_A2A_BACKEND = original_backend
+
+
 # The type of method in top-K routing, for use in torch custom op
 # Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
 class RoutingMethodType(IntEnum):
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 8baddc56f0c5..3437f2ef869d 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -764,7 +764,12 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
 
         bs = self.batch_size
 
-        if self.forward_mode.is_decode():
+        if (
+            self.forward_mode.is_decode()
+            or self.forward_mode.is_target_verify()
+            or self.forward_mode.is_draft_extend(include_v2=True)
+            or self.forward_mode.is_idle()
+        ):
             if self.is_extend_in_batch and dp_padding_mode.is_max_len():
                 setattr(self, "_original_forward_mode", self.forward_mode)
                 self.forward_mode = ForwardMode.EXTEND
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 737dd5cf10ba..a6a38507b3e5 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2502,7 +2502,9 @@ def forward_normal_chunked_kv_prepare(
         )
 
     def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
-        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        has_extend_prefix = forward_batch.extend_prefix_lens_cpu is not None and any(
+            forward_batch.extend_prefix_lens_cpu
+        )
         # Only initialize the info once
         if has_extend_prefix and forward_batch.num_prefix_chunks is None:
             forward_batch.prepare_chunked_prefix_cache_info(q.device)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 69440d16a097..95577981cf5e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -167,6 +167,8 @@
     "cutlass",
 ]
 
+MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep"]
+
 MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]
 
 
@@ -394,6 +396,7 @@ class ServerArgs:
     speculative_token_map: Optional[str] = None
     speculative_attention_mode: str = "prefill"
     speculative_moe_runner_backend: Optional[str] = None
+    speculative_moe_a2a_backend: Optional[str] = None
 
     # Speculative decoding (ngram)
     speculative_ngram_min_match_window_size: int = 1
@@ -3007,6 +3010,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.speculative_moe_runner_backend,
             help="Choose the runner backend for MoE in speculative decoding.",
         )
+        parser.add_argument(
+            "--speculative-moe-a2a-backend",
+            type=str,
+            choices=MOE_A2A_BACKEND_CHOICES,
+            default=ServerArgs.speculative_moe_a2a_backend,
+            help="Choose the backend for MoE A2A in speculative decoding.",
+        )
 
         # Speculative decoding (ngram)
         parser.add_argument(
@@ -3065,7 +3075,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--moe-a2a-backend",
             type=str,
-            choices=["none", "deepep", "mooncake", "ascend_fuseep"],
+            choices=MOE_A2A_BACKEND_CHOICES,
             default=ServerArgs.moe_a2a_backend,
             help="Choose the backend for MoE A2A.",
         )
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index 51e7bf060856..b3d72df05cfe 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -67,6 +67,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
+    # Shape info for padding
+    num_tokens_per_batch: int = -1
+
     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_VERIFY)
 
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 0add45539222..07e3798f1f09 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -10,7 +10,10 @@
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import speculative_moe_backend_context
+from sglang.srt.layers.moe.utils import (
+    speculative_moe_a2a_backend_context,
+    speculative_moe_backend_context,
+)
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -132,7 +135,9 @@ def __init__(
             ctx = draft_tp_context(get_attention_tp_group())
         else:
            ctx = empty_context()
-        with ctx, speculative_moe_backend_context():
+        with (
+            ctx
+        ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             super().__init__(
                 server_args=server_args,
                 gpu_id=gpu_id,
@@ -183,7 +188,7 @@ def __init__(
         )
         with self.draft_tp_context(
             self.draft_model_runner.tp_group
-        ), speculative_moe_backend_context():
+        ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             self.init_attention_backend()
             self.init_cuda_graphs()
 
@@ -276,7 +281,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
             )
             with self.draft_tp_context(
                 self.draft_model_runner.tp_group
-            ), speculative_moe_backend_context():
+            ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
@@ -289,7 +294,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         else:
             with self.draft_tp_context(
                 self.draft_model_runner.tp_group
-            ), speculative_moe_backend_context():
+            ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
@@ -297,7 +302,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
 
             with self.draft_tp_context(
                 self.draft_model_runner.tp_group
-            ), speculative_moe_backend_context():
+            ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
                 # NOTE: We should use `check_forward_draft_extend_after_decode`
                 # when DP attention is enabled, but it is slow. Skip it for now.
                 if (
@@ -665,6 +670,7 @@ def clear_cache_pool(self):
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
+        spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
         batch.return_hidden_states = False
         batch.forward_mode = (
             ForwardMode.TARGET_VERIFY
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 579d8f57c89d..629ddf27ef02 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -12,7 +12,10 @@
 from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import (
     EAGLEDraftNpuGraphRunner,
 )
-from sglang.srt.layers.moe.utils import speculative_moe_backend_context
+from sglang.srt.layers.moe.utils import (
+    speculative_moe_a2a_backend_context,
+    speculative_moe_backend_context,
+)
 from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
@@ -112,7 +115,7 @@ def __init__(
         self.req_to_token_pool, self.token_to_kv_pool_allocator = (
             target_worker.get_memory_pool()
         )
-        with empty_context(), speculative_moe_backend_context():
+        with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             # Init draft worker
             self.draft_worker = TpModelWorker(
                 server_args=server_args,
@@ -140,7 +143,7 @@ def __init__(
         )
         with self.draft_tp_context(
             self.draft_runner.tp_group
-        ), speculative_moe_backend_context():
+        ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             self.init_attention_backend()
             self.init_cuda_graphs()
 
@@ -611,12 +614,15 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
 
             # Draft prefill
             model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
-            batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
-                model_worker_batch,
-                batch_output.logits_output.hidden_states,
-                batch_output.next_token_ids,
-            )
-            return batch_output
+            with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+                batch_output.next_draft_input = (
+                    self.draft_worker._draft_extend_for_prefill(
+                        model_worker_batch,
+                        batch_output.logits_output.hidden_states,
+                        batch_output.next_token_ids,
+                    )
+                )
+            return batch_output
         else:
             if model_worker_batch.spec_info is None:
                 model_worker_batch.spec_info = EagleDraftInput.create_idle_input(
@@ -626,11 +632,17 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
                     topk=self.topk,
                     capture_hidden_mode=CaptureHiddenMode.LAST,
                 )
-            verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
+            with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+                verify_input: EagleVerifyInput = self.draft_worker.draft(
+                    model_worker_batch
+                )
             assert verify_input.is_verify_input()
             model_worker_batch.spec_info = verify_input
             batch_output = self.verify(model_worker_batch)
-            self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
+            with speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+                self.draft_worker._draft_extend_for_decode(
+                    model_worker_batch, batch_output
+                )
             return batch_output
 
     def verify(self, batch: ModelWorkerBatch):
@@ -643,6 +655,7 @@ def verify(self, batch: ModelWorkerBatch):
 
         # Parse args
         verify_input: EagleVerifyInput = batch.spec_info
+        verify_input.num_tokens_per_batch = self.speculative_num_steps + 1
         bs = len(batch.seq_lens)
 
         # Batch 1: Target verify
diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py
index 230a6ed00cd5..e1f331975bde 100644
--- a/python/sglang/srt/speculative/standalone_worker.py
+++ b/python/sglang/srt/speculative/standalone_worker.py
@@ -3,7 +3,10 @@
 
 import torch
 
-from sglang.srt.layers.moe.utils import speculative_moe_backend_context
+from sglang.srt.layers.moe.utils import (
+    speculative_moe_a2a_backend_context,
+    speculative_moe_backend_context,
+)
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -67,7 +70,7 @@ def __init__(
         self.hot_token_id = None
 
         # Init draft worker
-        with empty_context(), speculative_moe_backend_context():
+        with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             TpModelWorker.__init__(
                 self,
                 server_args=server_args,
@@ -91,7 +94,7 @@ def __init__(
         )
         with self.draft_tp_context(
             self.draft_model_runner.tp_group
-        ), speculative_moe_backend_context():
+        ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             self.init_attention_backend()
             self.init_cuda_graphs()
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 08a13f604f81..6fa0b2404ba0 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -2797,6 +2797,8 @@ def require_mlp_tp_gather(server_args: ServerArgs):
     """
     Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
     """
+    from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
     if server_args.enable_dp_attention:
         assert server_args.dp_size > 1, "dp_size must be greater than 1"
         if (
@@ -2805,7 +2807,7 @@ def require_mlp_tp_gather(server_args: ServerArgs):
             return True
         elif not server_args.enable_dp_lm_head:
             return True
-        elif server_args.moe_a2a_backend == "none":
+        elif get_moe_a2a_backend().is_none():
             return True
         else:
             return (
@@ -2820,8 +2822,10 @@ def require_attn_tp_gather(server_args: ServerArgs):
     """
     Check if the input of attention is scattered.
    """
+    from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
     assert server_args.moe_dense_tp_size in [1, None]
-    if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
+    if not get_moe_a2a_backend().is_none() or server_args.moe_dense_tp_size == 1:
         if server_args.enable_dp_attention:
             return server_args.dp_size < server_args.tp_size
         else:
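
For reference, the override-and-restore semantics the patch introduces ("same as `--moe-a2a-backend` if unset", applied only around draft-model forwards) can be sketched in isolation. The snippet below is a simplified, self-contained illustration: the enum values, globals, and function names mirror `sglang.srt.layers.moe.utils`, but they are stand-ins written for this note, not the actual module.

```python
from contextlib import contextmanager
from enum import Enum
from typing import Optional


class MoeA2ABackend(Enum):
    NONE = "none"
    DEEPEP = "deepep"

    def is_none(self) -> bool:
        return self == MoeA2ABackend.NONE


# Module-level state, mirroring (in simplified form) moe/utils.py.
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
SPECULATIVE_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None


def initialize_moe_config(moe_a2a: str, speculative_moe_a2a: Optional[str]) -> None:
    """The speculative A2A backend falls back to the target's backend when unset."""
    global MOE_A2A_BACKEND, SPECULATIVE_MOE_A2A_BACKEND
    MOE_A2A_BACKEND = MoeA2ABackend(moe_a2a)
    SPECULATIVE_MOE_A2A_BACKEND = (
        MoeA2ABackend(speculative_moe_a2a)
        if speculative_moe_a2a is not None
        else MOE_A2A_BACKEND
    )


@contextmanager
def speculative_moe_a2a_backend_context():
    """Swap in the draft-model A2A backend, then restore the target's backend."""
    global MOE_A2A_BACKEND
    original = MOE_A2A_BACKEND
    try:
        MOE_A2A_BACKEND = SPECULATIVE_MOE_A2A_BACKEND
        yield
    finally:
        MOE_A2A_BACKEND = original


# Target model dispatches via deepep; the draft model is pinned to the plain path.
initialize_moe_config("deepep", "none")
assert MOE_A2A_BACKEND == MoeA2ABackend.DEEPEP
with speculative_moe_a2a_backend_context():
    assert MOE_A2A_BACKEND.is_none()              # draft forwards see "none"
assert MOE_A2A_BACKEND == MoeA2ABackend.DEEPEP    # restored for the target model

# When the speculative flag is unset, the draft inherits the target's backend.
initialize_moe_config("deepep", None)
with speculative_moe_a2a_backend_context():
    assert MOE_A2A_BACKEND == MoeA2ABackend.DEEPEP
```

At the CLI level this corresponds to launching with, for example, `--moe-a2a-backend deepep --speculative-moe-a2a-backend none`, so the target model keeps DeepEP all-to-all dispatch while the EAGLE draft model runs on the plain path.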