
[PCG] fix piecewise cuda graph for Qwen3.5 #19220

Merged
ispobock merged 4 commits into sgl-project:main from zminglei:qwen3.5-pcg
Feb 26, 2026

Conversation

@zminglei
Collaborator

@zminglei zminglei commented Feb 24, 2026

Motivation

fix piecewise cuda graph for Qwen3.5

Modifications

  1. fix piecewise cuda graph for Qwen3.5
  2. Clean up the legacy gdn_with_output code, as it is no longer used.

Accuracy Tests

main:

SGLANG_ENABLE_JIT_DEEPGEMM=0 python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 8 --mem-fraction-static 0.8 --context-length 262144 --reasoning-parser qwen3 --enable-piecewise-cuda-graph

[2026-02-24 04:43:15 TP4] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/jobuser/zminglei/sglang/venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2755, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/home/jobuser/zminglei/sglang/venv/lib/python3.10/site-packages/torch/_ops.py", line 841, in __call__
    return self._op(*args, **kwargs)
NotImplementedError: sgl_kernel::fp8_blockwise_scaled_mm: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps:  https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

This PR:

SGLANG_ENABLE_JIT_DEEPGEMM=0 python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 8 --mem-fraction-static 0.8 --context-length 262144 --reasoning-parser qwen3 --enable-piecewise-cuda-graph

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl --port 8000 --num-questions 1319 --parallel 1319
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [01:56<00:00, 11.35it/s]
Accuracy: 0.948
Invalid: 0.008
Latency: 116.176 s
Output throughput: 1831.956 token/s

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@zminglei zminglei marked this pull request as ready for review February 24, 2026 05:51

@zminglei
Collaborator Author

zminglei commented Feb 24, 2026

/tag-and-rerun-ci again

):
    output = torch.empty_like(hidden_states)
    if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
        gdn_with_output(
Collaborator
why remove this branch?

Collaborator Author

@zminglei zminglei Feb 24, 2026

This branch existed for PCG purposes (most likely copied over from the earlier qwen3_next.py). It is no longer needed since we added split_op inside RadixLinearAttention, which serves the same purpose; see PR #17613.
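
For context, a rough sketch of the split-op idea (hypothetical illustration, not sglang's actual implementation): piecewise CUDA graph capture partitions the model's op sequence at designated splitting ops such as the attention, so each static segment between them can be captured as its own graph while the splitting op runs eagerly:

```python
# Hypothetical illustration of piecewise graph splitting; op names and the
# split_graph helper are made up for this sketch, not taken from sglang.
def split_graph(ops, splitting_ops):
    """Partition a flat op list into segments, isolating splitting ops."""
    pieces, current = [], []
    for op in ops:
        if op in splitting_ops:
            if current:
                pieces.append(current)
            # The splitting op stays outside any captured CUDA graph.
            pieces.append([op])
            current = []
        else:
            current.append(op)
    if current:
        pieces.append(current)
    return pieces

ops = ["embed", "mlp1", "linear_attn", "mlp2", "full_attn", "lm_head"]
pieces = split_graph(ops, {"linear_attn", "full_attn"})
print(pieces)
# [['embed', 'mlp1'], ['linear_attn'], ['mlp2'], ['full_attn'], ['lm_head']]
```

With the split performed inside the attention module itself, a manual extend-mode branch around the attention call (like the removed one above) becomes redundant.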

Collaborator

@Oasis-Git Oasis-Git left a comment

Maybe split the VL model change into another PR if possible.

@ispobock ispobock merged commit b3202fe into sgl-project:main Feb 26, 2026
631 of 679 checks passed
klhhhhh pushed a commit to klhhhhh/sglang that referenced this pull request Feb 26, 2026
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
lawrence-harmonic added a commit to lawrence-harmonic/sglang that referenced this pull request Mar 10, 2026