Skip to content

[MUSA] Use MUSA-optimized operators in piecewise CUDA graph#23633

Merged
Kangyan-Zhou merged 3 commits into
sgl-project:mainfrom
popsiclexu:xzx/musa-pcg
May 12, 2026
Merged

[MUSA] Use MUSA-optimized operators in piecewise CUDA graph#23633
Kangyan-Zhou merged 3 commits into
sgl-project:mainfrom
popsiclexu:xzx/musa-pcg

Conversation

@popsiclexu
Copy link
Copy Markdown
Contributor

@popsiclexu popsiclexu commented Apr 24, 2026

Motivation

MUSA devices previously could not use piecewise CUDA graph due to multiple incompatibilities: torch.compile on MUSA fails to serialize custom device types in FX-generated code, and several operators (SiluAndMul, RMSNorm, FP8 quantization) lacked fake kernel registrations needed for torch.compile tracing. As a workaround, these operators fell back to native PyTorch implementations when piecewise CUDA graph was enabled, resulting in suboptimal performance. This PR enables piecewise CUDA graph to work correctly on MUSA devices and allows the use of MUSA-optimized operators (e.g., nn.SwishGLU).

Modifications

  1. Patch FX codegen for custom devices (patch_torch.py): Add patch_fx_custom_device() that post-processes FX-generated source code, replacing device(type='musa', index=N) with torch.device('musa:N'). This fixes torch.compile serialization failure on MUSA, which was the root cause of piecewise CUDA graph crashes.

  2. Register fake kernels for torch.compile tracing:

    • activation.py: Register aten::_fused_swiglu_forward fake kernel so nn.SwishGLU can be traced.
    • fp8_kernel.py: Register sgl_kernel::sgl_per_token_group_quant_8bit_v2 fake kernel so FP8 quantization can be traced.
  3. Remove piecewise CUDA graph workarounds in MUSA operators (activation.py, layernorm.py): Remove the if not get_global_server_args().disable_piecewise_cuda_graph: return self.forward_native(x) guards in SiluAndMul.forward_musa and RMSNorm.forward_musa, so MUSA-optimized code paths are used even when piecewise CUDA graph is enabled.

Accuracy Tests

run server:

python3 -m sglang.launch_server \
  --model-path "/home/dist/DeepSeek-V2-Lite-Chat-FP8/ \
  --served-model-name "deepseek" \
  --mem-fraction-static 0.5 \
  --disable-overlap-schedule \
  --attention-backend fa3 \
  --cuda-graph-bs $(seq 1 16) \
  --chunked-prefill-size -1 \
  --disable-radix-cache  \
  --piecewise-cuda-graph-max-tokens 2048
python3 -m sglang.test.few_shot_gsm8k   --host http://127.0.0.1   --port 30000   --num-questions 200   --parallel 16
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:06<00:00,  3.00it/s]
Accuracy: 0.650
Invalid: 0.005
Latency: 66.860 s
Output throughput: 381.153 token/s

Known Limitations

Current piecewise CUDA graph does not support --moe-runner-backend=deepgemm on MUSA. We plan to add support in a follow-up PR.

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements MUSA device support, focusing on piecewise CUDA graph execution and kernel registrations. Key updates include a torch.fx patch for MUSA device serialization, fake kernel registrations for activation and quantization layers, and the removal of redundant CUDA graph checks. Feedback highlights improvements for robustness, such as using _replace for NamedTuple objects, safely accessing torch.version attributes, and employing conditional fake registrations for kernels.

Comment thread python/sglang/srt/utils/patch_torch.py Outdated
Comment thread sgl-kernel/python/sgl_kernel/utils.py Outdated
Comment thread python/sglang/srt/layers/quantization/fp8_kernel.py Outdated
@popsiclexu popsiclexu force-pushed the xzx/musa-pcg branch 2 times, most recently from 50310ee to e357993 Compare May 6, 2026 02:29
popsiclexu pushed a commit to popsiclexu/sglang that referenced this pull request May 8, 2026
- Use NamedTuple._replace() instead of positional PythonCode() constructor
  in patch_fx_custom_device for robustness across PyTorch versions
- Use getattr() for torch.version.musa/hip to avoid AttributeError
- Use register_fake_if_exists() instead of torch.library.register_fake()
  for safe conditional fake kernel registration

Signed-off-by: popsiclexu <zhenxue.xu@mthreads.com>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: popsiclexu <zhenxuexu@gmail.com>
@popsiclexu popsiclexu marked this pull request as ready for review May 8, 2026 03:40
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

popsiclexu pushed a commit to popsiclexu/sglang that referenced this pull request May 8, 2026
- Use NamedTuple._replace() instead of positional PythonCode() constructor
  in patch_fx_custom_device for robustness across PyTorch versions
- Use getattr() for torch.version.musa/hip to avoid AttributeError
- Use register_fake_if_exists() instead of torch.library.register_fake()
  for safe conditional fake kernel registration

Signed-off-by: popsiclexu <zhenxue.xu@mthreads.com>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: popsiclexu <zhenxuexu@gmail.com>
@popsiclexu popsiclexu force-pushed the xzx/musa-pcg branch 2 times, most recently from 6bc03b8 to 1e4066a Compare May 8, 2026 05:00
Comment thread python/sglang/srt/utils/patch_torch.py Outdated
Comment thread python/sglang/srt/layers/activation.py
@github-actions github-actions Bot added dependencies Pull requests that update a dependency file mthreads labels May 11, 2026
Signed-off-by: popsiclexu <zhenxuexu@gmail.com>
@yeahdongcn
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@Kangyan-Zhou Kangyan-Zhou merged commit bfc2eda into sgl-project:main May 12, 2026
181 of 210 checks passed
LucQueen pushed a commit to LucQueen/sglang that referenced this pull request May 12, 2026
SpencerGarnets added a commit to ai-blaise/optimization-playground that referenced this pull request May 12, 2026
…ack)

Brings in upstream sgl-project/sglang main commits since
096ad02 (merge base, Laguna-XS.2 model support).
Total: 28 upstream commits composed.

Custom-stack files preserved intact (entirely-ours, byte-identical to
origin/main):
  - Blackwell CuTe kernel suite (warp_decode_cute, g1_attention_cute,
    gated_norm_cute, layersplit_cute, fused_store_index_cache)
  - TurboQuant 2.5-bit dense KV cache path
  - HIGGS 2-bit dense KV cache path (with split-K decode)
  - NVFP4 IndexCache dispatcher (active gate)
  - quantization_config_dispatch (HF-config-driven runtime routing)
  - All custom server-args flags and runtime methods preserved

Verification:
  - 200+ merged Python files compile cleanly
  - Dispatcher symbol presence verified
  - HIGGS pool / TurboQuant pool classes present at expected lines
  - compressed_tensors_w4a4_nvfp4_moe imports clean
  - All custom server-args flags present (enable_higgs_dense_2bit_kv_cache,
    enable_turboquant_dense_kv_cache, turboquant_dense_kv_preset,
    indexer_quantization_declared, higgs_mla_decode_num_splits, etc.)

Manual-merged shared files (auto-merge gave broken/mixed output; cleaned
up post-merge):
  - python/sglang/srt/disaggregation/mooncake/conn.py: upstream's PR#24932
    refactored maybe_send_extra into a state-types-loop. Replayed our
    LayerSplit NSA state-index-length-mismatch check inside the SWA/NSA
    branch of the new loop body.
  - sgl-kernel/python/sgl_kernel/__init__.py: upstream's PR#23449 (Apple
    Silicon Metal kernel) wrapped the entire module body in
    `if darwin/arm64: from sgl_kernel.metal import * else: ...`. The
    auto-merge duplicated the file body; rewrote cleanly with upstream's
    structure and re-injected our `g1_gate_forward`,
    `warp_decode_cute_moe_forward`, and
    `warp_decode_cute_moe_packed_forward` imports plus `g1_gate_forward`
    in _DEBUG_EXPORT_NAMES.
  - python/sglang/srt/managers/scheduler_output_processor_mixin.py: line
    628 still referenced `result.num_accepted_drafts` (renamed by PR
    sgl-project#25038 to `num_correct_drafts`). Renamed in place.
  - python/sglang/srt/observability/scheduler_metrics_mixin.py: a block
    around the spec-decode logging path had mixed old/new names from
    auto-merge (lines 553/557/560). Renamed `spec_num_accepted_tokens`
    -> `spec_num_accept_tokens` and local `num_accepted_drafts` ->
    `num_correct_drafts` to match the rest of the file.
  - test/test_smc_info.py: stub Req mock used the old field names
    `spec_accepted_drafts` and `update_spec_acceptance_histogram`.
    Renamed to `spec_num_correct_drafts` and
    `update_spec_correct_drafts_histogram` per PR sgl-project#24081.

Auto-merge cleanly integrated upstream changes to:
  - server_args.py (new fields: prefill_only_disable_kv_cache,
    weight_loader_drop_cache_after_load, prefill_delayer_queue_min_ratio,
    prefill_delayer_max_delay_ms, speculative_draft_window_size, etc.)
  - mem_cache/memory_pool.py (new NoOpMHATokenToKVPool)
  - model_executor/model_runner_kv_cache_mixin.py (NoOpMHATokenToKVPool
    pool factory + _validate_prefill_only_disable_kv_cache_pool_family)
  - layers/attention/nsa_backend.py (spec rename
    num_accepted_drafts -> num_correct_drafts;
    num_accepted_tokens -> num_accept_tokens)
  - layers/attention/nsa/nsa_indexer.py (new _apply_q_scale_and_softmax_scale
    compile method; torch.mm replaces deep_gemm wrapper)
  - 28+ disaggregation/spec/runner files with mostly clean
    upstream-side-only integration.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

----- upstream commit subjects (28) -----
fd3eb77 [Cookbook]: add Laguna-XS.2 (Poolside) (sgl-project#24730)
6be1a45 Fix swa component host hit (sgl-project#25085)
693f497 [NPU] use causal_conv1d_update_v2 for performance (sgl-project#24595)
1efe9e2 [Bug Fix] Reject incompatible combination of --disable-cuda-graph-padding and --enable-torch-compile (sgl-project#23903)
8d27ce7 Optimize uvicorn startup command (sgl-project#25041)
b35fd5f [fix] skip legacy minicpmv conv template for MiniCPM-V 4.6 (sgl-project#24998)
7582237 [Tiny Fix] Disable BCG when inner layer_model unresolved (sgl-project#25021)
ca3bc05 Deepseek-v4-Pro share expert tp1 (sgl-project#24949)
a72d3ae [Spec] Multi-layer mamba scatter cleanup; fix positional call bug (sgl-project#25030)
7128533 Revert "Migrate Intel CPU cases to the test/registered." (sgl-project#25044)
1f985c5 [Spec] Rename `accepted_indices` -> `accept_indices`; drop `_token_id` suffix per Rule 5 (sgl-project#25038)
ecf5d84 Migrate Intel CPU cases to the test/registered. (sgl-project#22670)
d7f4761 [PD] Refactor hybrid state transfer (sgl-project#24932)
91907b7 [UnifiedTree]: Fix Unified HiCache tombstone lock release replay (sgl-project#24972)
4ad63ad [Spec] Rename `accepted_drafts` -> `correct_drafts` for unambiguous naming (sgl-project#24081)
6bfb365 [PD] Rate limit prefill inflight polling warnings (sgl-project#24967)
6bb79c1 [Linear Attn] Add CUSTOM enum and plugin extensibility for kernel backends (sgl-project#24937)
cfc41d5 Fix kimi k2.5 mla eagle + dp attention (sgl-project#25033)
0f3932c [Fix] Qwen3-ASR config: set thinker_config before super().__init__ (sgl-project#24187)
f526e3f [Spec] Mamba scatter cleanup; fix multi-layer positional bug; dflash naming (sgl-project#25029)
10375a1 [NIXL][XPU] Fix uint64 overflow for mismatched P/D TP sizes (e.g. prefill_tp=1, decode_tp=2) (sgl-project#24648)
0a37d24 [diffusion] hardware: support sage attention backend on MUSA (attn backend, 21/N) (sgl-project#24752)
5495026 [HiCache] feat: default storage prefetch timeout (sgl-project#23309)
186eb42 Feat: Support SWA (Sliding Window Attention) for EAGLE-3 drafter (sgl-project#24664)
a75b79e Feat: Support newer EAGLE-3 drafters (sgl-project#24663)
f3a8189 [Spec] Internal rename per N2 v2 naming rule (sgl-project#25014)
bfc2eda [MUSA] Use MUSA-optimized operators in piecewise CUDA graph (sgl-project#23633)
74d70af [Apple Silicon] Add Metal kernel support in sgl-kernel (sgl-project#23449)
xjpang pushed a commit to xjpang/sglang that referenced this pull request May 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dependencies Pull requests that update a dependency file mthreads run-ci sgl-kernel

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants