[MUSA][16/N] Add MUSA backend support for layers and DeepSeek models (V2/V3/R1)#22774

Merged
Kangyan-Zhou merged 2 commits into sgl-project:main from popsiclexu:xzx/deepseek
Apr 24, 2026

Conversation

@popsiclexu (Contributor) commented Apr 14, 2026

Motivation

Enable SGLang to run DeepSeek models on Moore Threads MUSA GPUs. This PR adds MUSA backend support across the inference stack, including layers, quantization, MoE, attention, speculative decoding, and custom op registration.

Modifications

Core Infrastructure:

  • Custom op registration (utils/common.py): Register custom ops on the MUSA dispatch key in direct_register_custom_op; extend get_device_sm() to detect MUSA SM versions via torch.cuda.get_device_capability() (works through torchada). A minimal sketch follows this list.
  • Server args (server_args.py): Add musa as a recognized device in the CLI help string.
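
The registration flow is roughly as follows. This is a minimal sketch assuming a "sglang" op namespace and a PrivateUse1 dispatch key for MUSA; the actual dispatch-key string and guards in utils/common.py may differ.

import torch
from torch.library import Library

_LIB = Library("sglang", "FRAGMENT")  # assumed op namespace

def direct_register_custom_op(op_name, op_func, mutates_args, fake_impl=None):
    # Define the op schema once, then attach a device-specific kernel.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    _LIB.define(op_name + schema)
    # Assumption: MUSA ops go on the PrivateUse1 key when torch_musa is present;
    # otherwise the existing CUDA key is used.
    dispatch_key = "PrivateUse1" if hasattr(torch, "musa") else "CUDA"
    _LIB.impl(op_name, op_func, dispatch_key)
    if fake_impl is not None:
        torch.library.register_fake(f"sglang::{op_name}", fake_impl)

def get_device_sm() -> int:
    # On MUSA, torch.cuda.get_device_capability() is routed through torchada,
    # so the same query covers both CUDA and MUSA devices.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0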

Layer Support:

  • Activation (activation.py): Add forward_musa for SiluAndMul — uses forward_native under piecewise CUDA graph, otherwise uses lazily-initialized nn.SwishGLU for better MUSA performance.
  • LayerNorm (layernorm.py): Add forward_musa for RMSNorm with fused add_rmsnorm and a native nn.functional.rms_norm fallback; enable flashinfer layernorm import on MUSA. See the sketch after this list.
  • Quantization (fp8.py, fp8_kernel.py, fp8_utils.py, unquant.py):
    • Set FP8 min capability to SM 31 on MUSA.
    • Import sgl_per_token_quant_fp8 from sgl_kernel on MUSA; force v2 per-token group quant path by default.
    • Make FP8 matmul scale tensors contiguous for DeepGEMM on MUSA (required by deep_gemm_fp8_fp8_bf16_nt).
    • Add forward_musa to UnquantizedFusedMoEMethod delegating to forward_cuda.
  • Sampler (sampler.py): Reformat import block (no functional change).
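
For illustration, the RMSNorm forward_musa path might look like the sketch below; the fused_add_rmsnorm import path and the exact guards in layernorm.py are assumptions.

import torch
from torch import nn
from sgl_kernel import fused_add_rmsnorm  # assumed import path for the fused kernel

def forward_musa(self, x: torch.Tensor, residual: torch.Tensor | None = None):
    if residual is not None:
        # Fused residual-add + RMSNorm, updating x and residual in place.
        fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
        return x, residual
    # Native PyTorch fallback when no fused kernel applies.
    return nn.functional.rms_norm(
        x, (x.shape[-1],), self.weight.data, self.variance_epsilon
    )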

MoE:

  • Fused MoE (fused_moe.py, moe_align_block_size.py): Import moe_sum_reduce on MUSA; use pre-instantiated nn.SwishGLU/nn.GELU for activation; skip pre-allocating intermediate_cache2 on MUSA; enable moe_sum_reduce for combine; import moe_align_block_size on MUSA.
  • Triton MoE runner (moe_runner/triton.py): Use pre-instantiated nn.SwishGLU/nn.GELU for activation; skip pre-allocating intermediate_cache2 on MUSA.
  • DeepGEMM MoE runner (moe_runner/deep_gemm.py): Use DEEPGEMM_NEED_TMA_ALIGNED_SCALES flag to skip TMA alignment on MUSA; skip get_mn_major_tma_aligned_tensor for scale tensors on MUSA.
  • EP MoE (ep_moe/kernels.py): Add ATOMIC_ADD_SEM="relaxed" for Triton tl.atomic_add on MUSA for better performance; enable per_token_group_quant_fp8 import on MUSA. See the sketch after this list.
  • TopK (topk.py): Import moe_fused_gate and topk_softmax from sgl_kernel on MUSA; enable biased_grouped_topk_gpu path.
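
The relaxed-atomics idea is sketched below on an illustrative accumulation kernel; the kernel name and arguments are stand-ins, not the actual ep_moe/kernels.py signatures.

import triton
import triton.language as tl

ATOMIC_ADD_SEM = "relaxed"  # MUSA: relaxed ordering is faster than the default acq_rel

@triton.jit
def _accumulate_kernel(src_ptr, dst_ptr, n_elements,
                       BLOCK: tl.constexpr, SEM: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    vals = tl.load(src_ptr + offs, mask=mask, other=0.0)
    # `sem` selects the memory-ordering semantics of the atomic update.
    tl.atomic_add(dst_ptr + offs, vals, mask=mask, sem=SEM)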

DeepSeek Model:

  • Weight loader (deepseek_weight_loader.py): Enable FP8 block quant weight loading on MUSA; dequantize w_kc/w_vc to bfloat16 at load time (temporary until an FP8 bmm kernel is available on MUSA). See the sketch after this list.
  • Attention (forward_mha.py, forward_mla.py): Enable MHA fused QK-NOPE path on MUSA; add MLA torch.bmm path for w_vc projection on MUSA.
  • Model (deepseek_v2.py): Import dsv3_fused_a_gemm and dsv3_router_gemm from sgl_kernel; enable alt_stream, shared expert fusion (SM >= 31), and routed_scaling_factor handling on MUSA.
  • Utils (utils.py): Export _is_musa flag for deepseek common modules.
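
A hedged sketch of the load-time dequantization, assuming a 128x128 block-scale layout; tensor names and shapes are illustrative, not the exact deepseek_weight_loader.py code.

import torch

def dequant_block_fp8_to_bf16(w_fp8: torch.Tensor, scale: torch.Tensor,
                              block_size: int = 128) -> torch.Tensor:
    # w_fp8: (n, k) float8_e4m3fn weights; scale: (ceil(n/block), ceil(k/block)) per-block scales.
    w = w_fp8.to(torch.float32)
    s = scale.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    return (w * s[: w.shape[0], : w.shape[1]]).to(torch.bfloat16)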

Speculative Decoding:

  • Eagle (eagle_utils.py, eagle_worker.py): Import build_tree_kernel_efficient and verify_tree_greedy on MUSA; disable torch.compile on MUSA.
  • Spec utils (spec_utils.py): Refactor a Triton boolean expression to avoid chained or operators (a MUSA Triton limitation). See the sketch after this list.
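
The rewrite pattern, shown on an illustrative kernel; the actual conditions in spec_utils.py differ.

import triton
import triton.language as tl

@triton.jit
def _verify_flags_kernel(flags_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    f = tl.load(flags_ptr + offs, mask=mask, other=0)
    # Before: `done = cond_a or cond_b or cond_c` (chained `or`), which the
    # MUSA Triton frontend cannot lower. After: combine conditions with `|`.
    done = (f == 1) | (f == 2) | (f == 3)
    tl.store(out_ptr + offs, done.to(tl.int32), mask=mask)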

Other:

  • Observability (scheduler_metrics_mixin.py): Add musa graph label.
  • Environ (environ.py): Add SGLANG_DEEPGEMM_SANITY_CHECK env var.
  • DeepGEMM wrapper (configurer.py, entrypoint.py, compile_utils.py): Enable DeepGEMM on MUSA SM >= 31; skip the JIT compilation hook on MUSA (returns nullcontext()); introduce the DEEPGEMM_NEED_TMA_ALIGNED_SCALES flag; migrate the sanity check flag to the centralized envs config. See the sketch after this list.
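
A minimal sketch of how the MUSA guards fit together; `_is_musa` detection, `_jit_compile_context`, and the helper names are stand-ins for the actual deep_gemm_wrapper code.

import contextlib
import torch

_is_musa = hasattr(torch, "musa")  # assumption: torch_musa exposes torch.musa

DEEPGEMM_NEED_TMA_ALIGNED_SCALES = not _is_musa  # MUSA DeepGEMM takes plain contiguous scales

def deep_gemm_compile_hook():
    if _is_musa:
        # MUSA ships prebuilt kernels, so the JIT capture/compile step is a no-op.
        return contextlib.nullcontext()
    return _jit_compile_context()  # CUDA path (stand-in name)

def maybe_align_scales(scale: torch.Tensor) -> torch.Tensor:
    if DEEPGEMM_NEED_TMA_ALIGNED_SCALES:
        return get_mn_major_tma_aligned_tensor(scale)  # CUDA: TMA-aligned layout
    return scale.contiguous()  # MUSA: contiguity is enough for deep_gemm_fp8_fp8_bf16_nt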

Accuracy Tests

Note: The accuracy tests below were run with --disable-piecewise-cuda-graph because some MUSA-optimized kernels are not yet compatible with the piecewise CUDA graph mechanism. We will address this incrementally so that piecewise CUDA graphs are supported on MUSA in future updates.

DeepSeek-Coder-V2-Lite-Instruct

run command:

python3 -m sglang.launch_server \
  --model-path "$MODEL_PATH" \
  --served-model-name "$SERVED_MODEL_NAME" \
  --mem-fraction-static 0.75 \
  --disable-overlap-schedule \
  --host "$HOST" \
  --port "$PORT" \
  --attention-backend fa3 \
  --cuda-graph-bs $(seq 1 2 64) \
  --tp-size 1 \
  --chunked-prefill-size -1 \
  --max-prefill-tokens 8192 \
  --disable-piecewise-cuda-graph

result:

python3 -m sglang.test.few_shot_gsm8k --host http://127.0.0.1 --port 30000 --num-questions 200 --parallel 64
100%|██████████| 200/200 [07:18<00:00, 2.19s/it]
Accuracy: 0.820
Invalid: 0.000
Latency: 442.925 s
Output throughput: 58.832 token/s

DeepSeek-V2-Lite-Chat-FP8

run command:

python3 -m sglang.launch_server \
  --model-path "$MODEL_PATH" \
  --served-model-name "$SERVED_MODEL_NAME" \
  --mem-fraction-static 0.75 \
  --disable-overlap-schedule \
  --host "$HOST" \
  --port "$PORT" \
  --attention-backend fa3 \
  --cuda-graph-bs $(seq 1 2 64) \
  --tp-size 1 \
  --chunked-prefill-size -1 \
  --max-prefill-tokens 8192 \
  --moe-runner-backend "deep_gemm" \
  --disable-piecewise-cuda-graph

result:

python3 -m sglang.test.few_shot_gsm8k --host http://127.0.0.1 --port 30000 --num-questions 200 --parallel 64
100%|██████████| 200/200 [00:38<00:00, 5.22it/s]
Accuracy: 0.665
Invalid: 0.000
Latency: 38.459 s
Output throughput: 647.187 token/s

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@popsiclexu (Contributor, Author)

/label mthreads

@gemini-code-assist (Bot) left a comment

Code Review

This pull request implements support for the MUSA (Moore Threads) architecture throughout the SGLang runtime, including updates to activation functions, layer normalization, MoE runners, and DeepSeek model implementations. The review feedback highlights a logic error in a platform check within the TopK selection and identifies several opportunities to improve performance by caching activation module instances instead of re-instantiating them during forward passes.

Comment thread python/sglang/srt/layers/moe/topk.py Outdated
Comment thread python/sglang/srt/layers/activation.py Outdated
Comment thread python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py Outdated
Comment thread python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py Outdated
Comment thread python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py Outdated
Comment thread python/sglang/srt/layers/moe/moe_runner/triton.py Outdated
Comment thread python/sglang/srt/layers/moe/moe_runner/triton.py Outdated
Comment thread python/sglang/srt/layers/moe/moe_runner/triton.py Outdated
popsiclexu marked this pull request as draft April 14, 2026 08:20
popsiclexu force-pushed the xzx/deepseek branch 3 times, most recently from 6403f6f to 1b9f2fa on April 14, 2026 10:12
Comment thread python/sglang/srt/layers/moe/moe_runner/deep_gemm.py Outdated
Comment thread python/sglang/srt/layers/moe/token_dispatcher/deepep.py Outdated
Comment thread python/sglang/srt/layers/quantization/fp8_kernel.py Outdated
Comment thread python/sglang/srt/layers/quantization/fp8_kernel.py Outdated
Comment thread python/sglang/srt/layers/quantization/fp8_kernel.py
Comment thread python/sglang/srt/model_executor/model_runner.py Outdated
Comment thread python/sglang/srt/speculative/eagle_info.py Outdated
Comment thread python/sglang/srt/server_args.py Outdated
Comment thread python/sglang/srt/layers/moe/ep_moe/kernels.py
yeahdongcn changed the title from "[MUSA][16/N] Add MUSA backend support for layers and Deepseek models (V2/V3/R1)" to "[MUSA][16/N] Add MUSA backend support for layers and DeepSeek models (V2/V3/R1)" on Apr 14, 2026
Comment thread python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
Comment thread python/sglang/srt/layers/deep_gemm_wrapper/configurer.py Outdated
Comment thread python/sglang/srt/layers/deep_gemm_wrapper/configurer.py Outdated
Comment thread python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py
Comment thread python/sglang/srt/layers/moe/moe_runner/deep_gemm.py Outdated
Comment thread python/sglang/srt/layers/activation.py Outdated
Comment thread python/sglang/srt/utils/common.py Outdated
popsiclexu marked this pull request as ready for review April 15, 2026 11:26
@yeahdongcn (Collaborator)

/tag-and-rerun-ci

yeahdongcn requested a review from alexnails April 15, 2026 11:29
@yeahdongcn (Collaborator)

/rerun-failed-ci

3 similar comments

@yeahdongcn (Collaborator)

/rerun-failed-ci

2 similar comments

@yeahdongcn (Collaborator)

/rerun-failed-ci

1 similar comment

@popsiclexu (Contributor, Author)

/rerun-failed-ci

@yeahdongcn (Collaborator)

/rerun-failed-ci

1 similar comment

popsiclexu and others added 2 commits April 22, 2026 10:51
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@yeahdongcn (Collaborator)

/rerun-failed-ci

2 similar comments

Kangyan-Zhou merged commit b35213b into sgl-project:main on Apr 24, 2026
463 of 562 checks passed
hnyls2002 mentioned this pull request Apr 29, 2026
