diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index b1e0f1ebb314..891cbacfaaf4 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -259,8 +259,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float | | `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str | | `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` | -| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | `None` | | -| `--speculative-moe-a2a-backend` | MOE A2A backend for EAGLE speculative decoding, see --moe-a2a-backend for options. Same as moe a2a backend if unset. | `None` | | +| `--speculative-moe-runner-backend` | MOE backend for EAGLE speculative decoding, see --moe-runner-backend for options. Same as moe runner backend if unset. | `None` | | ## Ngram speculative decoding | Argument | Description | Defaults | Options | @@ -277,7 +276,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Argument | Description | Defaults | Options | | --- | --- | --- | --- | | `--expert-parallel-size`
`--ep-size`
`--ep` | The expert parallelism size. | `1` | Type: int | -| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `ascend_fuseep`| +| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep` | | `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl` | | `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` | | `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) | diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 9740095a5e3f..b409cfdbf4ae 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -594,6 +594,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens = seq_lens + self.num_draft_tokens self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32) elif forward_batch.forward_mode.is_draft_extend(include_v2=True): + max_seq = forward_batch.seq_lens_cpu.max().item() + sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu) max_seq_len_q = max(forward_batch.extend_seq_lens_cpu) cu_seqlens_q = torch.nn.functional.pad( @@ -622,7 +624,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.forward_decode_metadata.block_kv_indices = block_kv_indices self.forward_decode_metadata.max_seq_len_k = int(max_seq) - self.forward_decode_metadata.batch_size = bs forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata else: @@ -862,14 +863,6 @@ def forward_decode( or self.forward_decode_metadata ) - # Ensure 
batch_size is sufficient, the batch size increase due to the padding from the forward batch - # FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends - # and padding logic in prepare_mlp_sync_batch to avoid this - batch_size = getattr(metadata, "batch_size", None) - if batch_size is not None and batch_size < forward_batch.batch_size: - self.init_forward_metadata(forward_batch) - metadata = forward_batch.decode_trtllm_mla_metadata - # Scale computation for TRTLLM MLA kernel BMM1 operation: # The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale # Scale components: @@ -983,14 +976,6 @@ def forward_extend( or self.forward_decode_metadata ) - # Ensure batch_size is sufficient, the batch size increase due to the padding from the forward batch - # FIXME(@rainj-me), refactor the skip_attn_backend_init, init_forward_metadata for attn backends - # and padding logic in prepare_mlp_sync_batch to avoid this - batch_size = getattr(metadata, "batch_size", None) - if batch_size is not None and batch_size < forward_batch.batch_size: - self.init_forward_metadata(forward_batch) - metadata = forward_batch.decode_trtllm_mla_metadata - # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim] bs = forward_batch.batch_size @@ -1012,6 +997,27 @@ def forward_extend( ) else: max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q + # Check if we're in CUDA graph mode (buffers are pre-allocated) + if self.padded_q_buffer is not None: + # Use pre-allocated buffer for CUDA graph compatibility + padded_q = self.padded_q_buffer[ + :bs, : metadata.max_seq_len_q, :, : + ].to(dtype=q.dtype) + else: + # Dynamic allocation for non-CUDA graph mode + padded_q = torch.zeros( + bs, + metadata.max_seq_len_q, + layer.tp_q_head_num, + layer.head_dim, + dtype=q.dtype, + device=q.device, + ) + q = self.pad_draft_extend_query( + q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q + ) + + # TODO may use `mla_rope_quantize_fp8` 
fusion q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) assert kv_cache.dtype == self.data_type @@ -1028,6 +1034,15 @@ def forward_extend( bmm1_scale=bmm1_scale, ) + # Reshape output directly without slicing + + if forward_batch.forward_mode.is_draft_extend(include_v2=True): + raw_out = self.unpad_draft_extend_output( + raw_out, + metadata.cu_seqlens_q, + metadata.seq_lens_q, + metadata.sum_seq_lens_q, + ) output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) return output diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 70466bb20838..85d28b148d73 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -122,7 +122,6 @@ def is_auto(self) -> bool: MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None SPECULATIVE_MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None -SPECULATIVE_MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None DEEPEP_MODE: Optional[DeepEPMode] = None IS_TBO_ENABLED: Optional[bool] = None IS_SBO_ENABLED: Optional[bool] = None @@ -136,7 +135,6 @@ def initialize_moe_config(server_args: ServerArgs): global MOE_A2A_BACKEND global MOE_RUNNER_BACKEND global SPECULATIVE_MOE_RUNNER_BACKEND - global SPECULATIVE_MOE_A2A_BACKEND global DEEPEP_MODE global DEEPEP_CONFIG global IS_TBO_ENABLED @@ -152,11 +150,6 @@ def initialize_moe_config(server_args: ServerArgs): if server_args.speculative_moe_runner_backend is not None else MOE_RUNNER_BACKEND ) - SPECULATIVE_MOE_A2A_BACKEND = ( - MoeA2ABackend(server_args.speculative_moe_a2a_backend) - if server_args.speculative_moe_a2a_backend is not None - else MOE_A2A_BACKEND - ) DEEPEP_MODE = DeepEPMode(server_args.deepep_mode) DEEPEP_CONFIG = server_args.deepep_config or "" IS_TBO_ENABLED = server_args.enable_two_batch_overlap @@ -196,16 +189,6 @@ def get_speculative_moe_runner_backend() -> MoeRunnerBackend: return SPECULATIVE_MOE_RUNNER_BACKEND -def 
get_speculative_moe_a2a_backend() -> MoeA2ABackend: - global SPECULATIVE_MOE_A2A_BACKEND - if SPECULATIVE_MOE_A2A_BACKEND is None: - logger.warning( - "SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend" - ) - SPECULATIVE_MOE_A2A_BACKEND = MoeA2ABackend.NONE - return SPECULATIVE_MOE_A2A_BACKEND - - def get_deepep_mode() -> DeepEPMode: global DEEPEP_MODE if DEEPEP_MODE is None: @@ -275,21 +258,6 @@ def speculative_moe_backend_context(): MOE_RUNNER_BACKEND = original_backend -@contextmanager -def speculative_moe_a2a_backend_context(): - """ - Context manager to temporarily use the speculative MoE A2A backend for draft model operations. - This ensures that draft models in speculative decoding use the configured speculative A2A backend. - """ - global MOE_A2A_BACKEND - original_backend = MOE_A2A_BACKEND - try: - MOE_A2A_BACKEND = MoeA2ABackend.NONE - yield - finally: - MOE_A2A_BACKEND = original_backend - - # The type of method in top-K routing, for use in torch custom op # Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h class RoutingMethodType(IntEnum): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index d4ed92d0ec2d..4de8a4ef7df8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -772,12 +772,7 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): bs = self.batch_size - if ( - self.forward_mode.is_decode() - or self.forward_mode.is_target_verify() - or self.forward_mode.is_draft_extend(include_v2=True) - or self.forward_mode.is_idle() - ): + if self.forward_mode.is_decode(): if self.is_extend_in_batch and dp_padding_mode.is_max_len(): setattr(self, "_original_forward_mode", self.forward_mode) self.forward_mode = ForwardMode.EXTEND diff --git 
a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ea8ec92a51d9..9f866499d987 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2520,9 +2520,7 @@ def forward_normal_chunked_kv_prepare( ) def forward_normal_chunked_kv_core(self, q, k, v, forward_batch): - has_extend_prefix = forward_batch.extend_prefix_lens_cpu is not None and any( - forward_batch.extend_prefix_lens_cpu - ) + has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu) # Only initialize the info once if has_extend_prefix and forward_batch.num_prefix_chunks is None: forward_batch.prepare_chunked_prefix_cache_info(q.device) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 68bd6709cb52..f95053206eb8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -168,8 +168,6 @@ "cutlass", ] -MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep"] - MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"] @@ -397,7 +395,6 @@ class ServerArgs: speculative_token_map: Optional[str] = None speculative_attention_mode: str = "prefill" speculative_moe_runner_backend: Optional[str] = None - speculative_moe_a2a_backend: Optional[str] = None # Speculative decoding (ngram) speculative_ngram_min_match_window_size: int = 1 @@ -3039,13 +3036,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.speculative_moe_runner_backend, help="Choose the runner backend for MoE in speculative decoding.", ) - parser.add_argument( - "--speculative-moe-a2a-backend", - type=str, - choices=MOE_A2A_BACKEND_CHOICES, - default=ServerArgs.speculative_moe_a2a_backend, - help="Choose the backend for MoE A2A in speculative decoding", - ) # Speculative decoding (ngram) parser.add_argument( @@ -3104,7 +3094,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--moe-a2a-backend", type=str, - choices=MOE_A2A_BACKEND_CHOICES, + 
choices=["none", "deepep", "mooncake", "ascend_fuseep"], default=ServerArgs.moe_a2a_backend, help="Choose the backend for MoE A2A.", ) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index b3d72df05cfe..51e7bf060856 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -67,9 +67,6 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): seq_lens_cpu: torch.Tensor grammar: BaseGrammarObject = None - # Shape info for padding - num_tokens_per_batch: int = -1 - def __post_init__(self): super().__init__(SpecInputType.EAGLE_VERIFY) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 07e3798f1f09..0add45539222 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -10,10 +10,7 @@ ) from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.moe.utils import ( - speculative_moe_a2a_backend_context, - speculative_moe_backend_context, -) +from sglang.srt.layers.moe.utils import speculative_moe_backend_context from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -135,9 +132,7 @@ def __init__( ctx = draft_tp_context(get_attention_tp_group()) else: ctx = empty_context() - with ( - ctx - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ctx, speculative_moe_backend_context(): super().__init__( server_args=server_args, gpu_id=gpu_id, @@ -188,7 +183,7 @@ def __init__( ) with self.draft_tp_context( self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + ), speculative_moe_backend_context(): 
self.init_attention_backend() self.init_cuda_graphs() @@ -281,7 +276,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul ) with self.draft_tp_context( self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + ), speculative_moe_backend_context(): self.forward_draft_extend( batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu ) @@ -294,7 +289,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul else: with self.draft_tp_context( self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + ), speculative_moe_backend_context(): spec_info = self.draft(batch) logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( self.verify(batch, spec_info) @@ -302,7 +297,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul with self.draft_tp_context( self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + ), speculative_moe_backend_context(): # NOTE: We should use `check_forward_draft_extend_after_decode` # when DP attention is enabled, but it is slow. Skip it for now. 
if ( @@ -670,7 +665,6 @@ def clear_cache_pool(self): def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): spec_info.prepare_for_verify(batch, self.page_size) - spec_info.num_tokens_per_batch = self.speculative_num_steps + 1 batch.return_hidden_states = False batch.forward_mode = ( ForwardMode.TARGET_VERIFY diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 629ddf27ef02..579d8f57c89d 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,10 +12,7 @@ from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import ( EAGLEDraftNpuGraphRunner, ) -from sglang.srt.layers.moe.utils import ( - speculative_moe_a2a_backend_context, - speculative_moe_backend_context, -) +from sglang.srt.layers.moe.utils import speculative_moe_backend_context from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.scheduler import GenerationBatchResult @@ -115,7 +112,7 @@ def __init__( self.req_to_token_pool, self.token_to_kv_pool_allocator = ( target_worker.get_memory_pool() ) - with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with empty_context(), speculative_moe_backend_context(): # Init draft worker self.draft_worker = TpModelWorker( server_args=server_args, @@ -143,7 +140,7 @@ def __init__( ) with self.draft_tp_context( self.draft_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + ), speculative_moe_backend_context(): self.init_attention_backend() self.init_cuda_graphs() @@ -614,15 +611,12 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): # Draft prefill model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST - with speculative_moe_backend_context(), 
speculative_moe_a2a_backend_context(): - batch_output.next_draft_input = ( - self.draft_worker._draft_extend_for_prefill( - model_worker_batch, - batch_output.logits_output.hidden_states, - batch_output.next_token_ids, - ) - ) - return batch_output + batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill( + model_worker_batch, + batch_output.logits_output.hidden_states, + batch_output.next_token_ids, + ) + return batch_output else: if model_worker_batch.spec_info is None: model_worker_batch.spec_info = EagleDraftInput.create_idle_input( @@ -632,17 +626,11 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) - with speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): - verify_input: EagleVerifyInput = self.draft_worker.draft( - model_worker_batch - ) + verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch) assert verify_input.is_verify_input() model_worker_batch.spec_info = verify_input batch_output = self.verify(model_worker_batch) - with speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): - self.draft_worker._draft_extend_for_decode( - model_worker_batch, batch_output - ) + self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output) return batch_output def verify(self, batch: ModelWorkerBatch): @@ -655,7 +643,6 @@ def verify(self, batch: ModelWorkerBatch): # Parse args verify_input: EagleVerifyInput = batch.spec_info - verify_input.num_tokens_per_batch = self.speculative_num_steps + 1 bs = len(batch.seq_lens) # Batch 1: Target verify diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py index e1f331975bde..230a6ed00cd5 100644 --- a/python/sglang/srt/speculative/standalone_worker.py +++ b/python/sglang/srt/speculative/standalone_worker.py @@ -3,10 +3,7 @@ import torch -from sglang.srt.layers.moe.utils import ( - 
speculative_moe_a2a_backend_context, - speculative_moe_backend_context, -) +from sglang.srt.layers.moe.utils import speculative_moe_backend_context from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -70,7 +67,7 @@ def __init__( self.hot_token_id = None # Init draft worker - with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with empty_context(), speculative_moe_backend_context(): TpModelWorker.__init__( self, server_args=server_args, @@ -94,7 +91,7 @@ def __init__( ) with self.draft_tp_context( self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + ), speculative_moe_backend_context(): self.init_attention_backend() self.init_cuda_graphs() diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index cf7264225dc6..78d40064f3c1 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2799,8 +2799,6 @@ def require_mlp_tp_gather(server_args: ServerArgs): """ Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups. """ - from sglang.srt.layers.moe.utils import get_moe_a2a_backend - if server_args.enable_dp_attention: assert server_args.dp_size > 1, "dp_size must be greater than 1" if ( @@ -2809,7 +2807,7 @@ def require_mlp_tp_gather(server_args: ServerArgs): return True elif not server_args.enable_dp_lm_head: return True - elif get_moe_a2a_backend().is_none(): + elif server_args.moe_a2a_backend == "none": return True else: return ( @@ -2824,10 +2822,8 @@ def require_attn_tp_gather(server_args: ServerArgs): """ Check if the input of attention is scattered. 
""" - from sglang.srt.layers.moe.utils import get_moe_a2a_backend - assert server_args.moe_dense_tp_size in [1, None] - if not get_moe_a2a_backend().is_none() or server_args.moe_dense_tp_size == 1: + if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1: if server_args.enable_dp_attention: return server_args.dp_size < server_args.tp_size else: