
Make GLM4 MoE dual stream capture-aware and configurable (restore non-DeepEP overlap)#13778

Open
voipmonitor wants to merge 4 commits into sgl-project:main from voipmonitor:fix/glm4-dual-stream-config

Conversation

@voipmonitor
Contributor

voipmonitor commented Nov 23, 2025

Summary:

Why:

Testing:

  • Local 4× RTX (FP8, CUDA graph enabled) with default auto: throughput restored to pre-#12455 ("[Grammar Fix] GLM-4-MOE self.first_k_dense_replace is undefined") levels.

  • Manual toggles:

    • SGLANG_GLM4_MOE_DUAL_STREAM=never → overlap disabled (expected slowdown).
    • SGLANG_GLM4_MOE_DUAL_STREAM=always → overlap forced even without capture.
    • SGLANG_GLM4_MOE_DUAL_STREAM_THRESHOLD=2048 → higher threshold in non-capture auto mode.
  • Included optimized Triton .json tuning configs for 4× RTX 6000 PRO.
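The four modes and the threshold described above can be sketched as a small decision helper. This is an illustrative reconstruction, not the PR's actual code: the function name, the "auto" default, and the fallback threshold of 1024 tokens are assumptions.

```python
import os
from typing import Optional


def should_use_dual_stream(
    is_cuda_graph_capture: bool,
    num_tokens: int,
    mode: Optional[str] = None,
    threshold: Optional[int] = None,
) -> bool:
    """Sketch of the dual-stream policy controlled by the env vars above."""
    mode = mode or os.environ.get("SGLANG_GLM4_MOE_DUAL_STREAM", "auto")
    if threshold is None:
        # Fallback default of 1024 is an assumption for illustration.
        threshold = int(os.environ.get("SGLANG_GLM4_MOE_DUAL_STREAM_THRESHOLD", "1024"))
    if mode == "never":
        return False
    if mode == "always":
        return True
    if mode == "capture":
        return is_cuda_graph_capture
    # "auto": mirror #9405 -- always overlap while CUDA graph capture is
    # active; otherwise decide by the token threshold.
    return is_cuda_graph_capture or num_tokens <= threshold
```

For example, `SGLANG_GLM4_MOE_DUAL_STREAM_THRESHOLD=2048` raises the cutoff used only in the non-capture "auto" branch, as in the manual toggle tests above.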

testing with:

NCCL_P2P_LEVEL=4 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
USE_TRITON_W8A8_FP8_KERNEL=1 \
SGLANG_ENABLE_JIT_DEEPGEMM=0 \
python -m sglang.launch_server --model /mnt/GLM-4.6-FP8/ --tp 4 --host 0.0.0.0 --port 4999 \
  --mem-fraction-static 0.96 --context-length 200000 --enable-metrics \
  --attention-backend flashinfer --tool-call-parser glm45 --reasoning-parser glm \
  --served-model-name glm-4.6-FP8 --chunked-prefill-size 8092 --enable-mixed-chunk \
  --cuda-graph-max-bs 16 --kv-cache-dtype fp8_e5m2 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}'

Before PR: ~52 tokens/sec
After PR: ~58 tokens/sec (the same speed as before #12455)

@gemini-code-assist
Contributor

Summary of Changes

Hello @voipmonitor, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a performance regression in GLM4 Mixture-of-Experts (MoE) models by re-introducing and enhancing the dual-stream processing capability. The changes ensure that the model can leverage parallel execution for improved throughput, especially when CUDA graph capture is active, and provide users with granular control over this behavior through new configuration options. This brings GLM4 MoE's performance and configurability in line with best practices observed in other models like DeepSeek.

Highlights

  • Dual-Stream MoE Path Restoration: The non-DeepEP dual-stream Mixture-of-Experts (MoE) path, which was inadvertently removed in a previous pull request (#12455, "[Grammar Fix] GLM-4-MOE self.first_k_dense_replace is undefined"), has been restored to address inference throughput regressions, particularly for small to medium micro-batches when using CUDA graphs.
  • Configurable Dual-Stream Behavior: The dual-stream behavior for GLM4 MoE is now configurable via environment variables: SGLANG_GLM4_MOE_DUAL_STREAM (with options 'auto', 'capture', 'always', 'never') and SGLANG_GLM4_MOE_DUAL_STREAM_THRESHOLD.
  • Aligned Default Behavior with DeepSeek: The default dual-stream policy now mirrors the DeepSeek change (#9405, "Use dual stream for DS MoE whenever cuda graph is used (instead of with token threshold)"), automatically enabling dual-stream when CUDA graph capture is active. Otherwise, it falls back to a configurable token threshold to decide whether to activate dual-stream.

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request restores the dual-stream MoE path for GLM-4 models to improve performance, making it configurable through environment variables. The changes are well-structured. However, I've found a critical issue in the dual-stream implementation that prevents the intended computation overlap, effectively serializing the operations. I've also suggested a small improvement for robustness in handling the configuration options. Addressing the critical issue is essential to achieve the performance goals of this PR.

) -> torch.Tensor:

current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
Contributor


critical

The call self.alt_stream.wait_stream(current_stream) serializes the execution on the alternate stream with the current stream. This means the operations inside the with torch.cuda.stream(self.alt_stream): block will only start after _forward_shared_experts on the current stream has completed. This defeats the purpose of using a dual-stream approach, which is to achieve parallelism and overlap computation. Removing this line will allow _forward_shared_experts and the operations on the alternate stream (gate, topk, experts) to run concurrently, restoring the intended performance improvement.

Contributor Author


wait_stream here doesn’t serialize the dual-stream overlap; it mirrors the original code’s ordering. The call alt_stream.wait_stream(current_stream) is placed before we enqueue shared_output = _forward_shared_experts on the current stream. That means alt stream only waits for prior work that produced hidden_states, not for shared_experts itself. After that, shared_experts runs on the current stream while gate/topk/experts run on the alt stream, so they still overlap. The later current_stream.wait_stream(self.alt_stream) is necessary before the final add. If we drop the initial wait_stream, we risk alt stream reading hidden_states before upstream ops on the current stream have finished.
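The ordering argument above can be illustrated with stand-in stream objects that merely record the enqueue order; the real code uses torch.cuda.Stream, and all names here are illustrative. The key semantic is that wait_stream only orders against work already enqueued on the other stream, not work enqueued later:

```python
class FakeStream:
    """CPU-side stand-in for torch.cuda.Stream that records enqueue order."""

    def __init__(self, name, trace):
        self.name, self.trace = name, trace

    def wait_stream(self, other):
        # Everything enqueued on `self` after this call waits only for work
        # already enqueued on `other` at this point -- not for later work.
        self.trace.append(f"{self.name} waits on prior {other.name} work")

    def run(self, op):
        self.trace.append(f"{op} on {self.name}")


trace = []
current = FakeStream("current", trace)
alt = FakeStream("alt", trace)

alt.wait_stream(current)        # alt waits only for the ops that produced hidden_states
current.run("shared_experts")   # enqueued after the wait, so alt does NOT wait for it
alt.run("gate+topk+experts")    # overlaps with shared_experts on the current stream
current.wait_stream(alt)        # join point required before the final add
current.run("final_add")
```

Because shared_experts is enqueued after the initial wait_stream, it is not part of the work the alt stream waits for, so the two streams still overlap; dropping the initial wait would instead let the alt stream read hidden_states before upstream ops finish.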

voipmonitor and others added 2 commits November 23, 2025 01:48
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@voipmonitor
Contributor Author

@zRzRzRzRzRzRzR - would you please review? Is it ok to restore the non-DeepEP dual-stream MoE path?
