[SGL] sync patch: Remove sync points, prefill cudagraph for DP, disable cache reset in mem check#19190

Merged
hnyls2002 merged 2 commits into sgl-project:main from YazhiGao:leon/sync_srt_20260223
Feb 28, 2026

Conversation

@YazhiGao
Contributor

@YazhiGao YazhiGao commented Feb 23, 2026

Summary

  • Remove CPU-GPU sync points in logits_processor, schedule_batch, and memory_pool:
    • use pin_memory=True + .to(device, non_blocking=True) instead of direct device tensor construction
    • pass output_size to repeat_interleave to avoid an implicit sync
    • build mamba track indices on GPU without scalar extraction
    • use a GPU zero-expand instead of index_fill_
  • Add all_extend_in_batch field in schedule_batch, forward_batch_info, and ModelWorkerBatch to support prefill CUDA graph capture with DP attention
  • Minor: connector device=None default, empty_cache=False in weight update logging
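The sync-avoidance patterns above can be sketched as follows. This is a minimal illustration under assumed shapes, not the actual sglang code; `to_device_async` and the tensor names are made up for the example:

```python
import torch

def to_device_async(values, device, dtype=torch.int64):
    """Stage a host list in pinned memory, then copy it without blocking.

    Building the tensor directly with device="cuda" issues a synchronous
    host-to-device copy; pinning + non_blocking lets the copy overlap with
    other GPU work. Pinning needs CUDA, so fall back on CPU-only machines.
    """
    pin = torch.cuda.is_available() and device != "cpu"
    host = torch.tensor(values, dtype=dtype, pin_memory=pin)
    return host.to(device, non_blocking=True)

# repeat_interleave with a tensor of repeats must otherwise read the total
# count back to the host to size its output; passing output_size skips that.
x = torch.tensor([7, 8])
repeats = torch.tensor([2, 3])
out = x.repeat_interleave(repeats, output_size=5)  # [7, 7, 8, 8, 8]
```

Note that `output_size` must match `repeats.sum()`; the caller is expected to already know the total on the host.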

Test plan

  • Verify no regressions on existing CI tests
  • Test with DP attention enabled

@YazhiGao YazhiGao changed the title [TML] Remove sync points, prefill cudagraph for DP, KL0 mode [SGL] sync patch: Remove sync points, prefill cudagraph for DP, KL0 mode Feb 23, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @YazhiGao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the SGLang runtime by removing unnecessary synchronization points, adding support for CUDA graph prefilling, and implementing a deterministic KL0 mode. These changes aim to improve performance and determinism in distributed training scenarios.

Highlights

  • Synchronization Removal: Removed CPU-GPU synchronization points in logits_processor, schedule_batch, and memory_pool by using pin_memory=True and .to(device, non_blocking=True) for data transfers and avoiding implicit synchronization in repeat_interleave.
  • CUDA Graph Prefill Support: Added all_extend_in_batch field to schedule_batch, forward_batch_info, and ModelWorkerBatch to support prefill CUDA graph capture with DP attention.
  • KL0 Mode Implementation: Implemented KL0 mode in parallel_state using c10d functional all_reduce/reduce_scatter for deterministic reductions matching the learner.


Changelog
  • python/sglang/srt/connector/__init__.py
    • Modified default device argument for create_remote_connector.
  • python/sglang/srt/distributed/parallel_state.py
    • Implemented KL0 mode using c10d functional all_reduce and reduce_scatter for deterministic reductions.
  • python/sglang/srt/layers/logits_processor.py
    • Replaced direct device tensor construction with pin_memory=True and .to(device, non_blocking=True) to remove sync points.
    • Avoided implicit device sync inside repeat_interleave by providing output_size.
  • python/sglang/srt/managers/schedule_batch.py
    • Added all_extend_in_batch field to support prefill CUDA graph capture with DP attention.
    • Refactored mamba track indices construction to be fully on GPU, avoiding scalar extraction.
    • Used pin_memory for mamba track mask construction.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Replaced index_fill_ with GPU zero-expand to avoid sync points during memory allocation.
    • Refactored mamba cache allocation to use torch.stack instead of list operations.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Added all_extend_in_batch field to support prefill CUDA graph capture with DP attention.
  • python/sglang/srt/model_executor/model_runner.py
    • Modified weight update logging to disable empty_cache for memory availability check.
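The index_fill_ replacement mentioned for memory_pool.py above can be illustrated like this. The shapes and variable names are hypothetical, not the actual pool internals:

```python
import torch

# index_fill_ with an index tensor built from host-side scalars forces a
# blocking host-to-device copy. Expanding a 0-dim zero that already lives
# on the target device keeps the whole operation on the GPU stream
# (demonstrated on CPU here for portability).
pool = torch.randn(8, 4)
rows = torch.tensor([1, 3, 6])  # in practice already a device tensor

# sync-prone alternative: pool.index_fill_(0, rows, 0)
zero = torch.zeros((), dtype=pool.dtype, device=pool.device)
pool[rows] = zero.expand(rows.numel(), pool.size(1))
```

The 0-dim scalar expands to the selected shape without allocating a full zero buffer, which is the "zero-expand" pattern the changelog refers to.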
Activity
  • The author, YazhiGao, has made changes to remove sync points, prefill cudagraph for DP, and add KL0 mode.
  • The test plan includes verifying no regressions on existing CI tests and testing with DP attention enabled.

@YazhiGao YazhiGao force-pushed the leon/sync_srt_20260223 branch from bac319b to 8c1c97c on February 23, 2026 at 18:57
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request focuses on optimizing GPU-CPU synchronization and enabling CUDA graph prefill for DP attention, along with introducing a KL0 mode for deterministic reductions. The changes involve modifying tensor creation and data transfer methods to reduce sync points, adding fields to support CUDA graph capture, and implementing a new reduction path for KL0 mode. Overall, the changes aim to improve performance and determinism.

I was unable to create individual review comments, so my feedback is inlined below.

python/sglang/srt/connector/__init__.py (25)

medium

The device parameter is given a default value of None, but its usage within the function (e.g., RedisConnector(url)) does not seem to utilize this parameter. If device is not used, it should be removed to maintain clean API design and avoid confusion.

def create_remote_connector(url, **kwargs) -> BaseConnector:

python/sglang/srt/distributed/parallel_state.py (551)

medium

The import statement for is_kl0_mode_enabled is placed inside the all_reduce method. This can lead to repeated imports every time the method is called, potentially impacting performance. It's better to move imports to the top of the file to ensure they are executed only once.

from sglang.private.utils.torch_utils import is_kl0_mode_enabled

        if is_kl0_mode_enabled():

python/sglang/srt/distributed/parallel_state.py (688)

medium

Similar to the previous comment, the import statement for is_kl0_mode_enabled is placed inside the _reduce_scatter_tensor method. This can lead to repeated imports every time the method is called, potentially impacting performance. It's better to move imports to the top of the file to ensure they are executed only once.

from sglang.private.utils.torch_utils import is_kl0_mode_enabled

        if is_kl0_mode_enabled():
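A sketch of the hoist the reviewer suggests. Since sglang.private is an internal module, the inline import may have been a deliberate availability guard; the try/except fallback below preserves that tolerance at module scope and is an assumption, not the project's actual fix:

```python
# At module scope, executed once at import time rather than on every
# all_reduce / _reduce_scatter_tensor call. (Repeated function-local
# imports only hit the sys.modules cache after the first call, so the
# per-call cost is small but nonzero.) The fallback covers builds where
# the private module is absent -- hypothetical handling.
try:
    from sglang.private.utils.torch_utils import is_kl0_mode_enabled
except ImportError:
    def is_kl0_mode_enabled() -> bool:
        return False
```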

python/sglang/srt/managers/schedule_batch.py (2018-2031)

medium

The logic for constructing self.mamba_track_indices is complex and involves conditional slicing within a list comprehension, followed by torch.cat. This could be simplified by first preparing the list of tensors and then concatenating them, or by using a more direct tensor operation if possible. The current approach is hard to read and debug.

            if len(self.reqs) == 0:
                self.mamba_track_indices = torch.empty(
                    (0,), dtype=torch.int64, device=self.device
                )
            else:
                mamba_track_tensors = []
                for req in self.reqs:
                    if req.mamba_next_track_idx == 1:
                        mamba_track_tensors.append(req.mamba_ping_pong_track_buffer[1:])
                    else:
                        mamba_track_tensors.append(req.mamba_ping_pong_track_buffer[:1])
                self.mamba_track_indices = torch.cat(
                    mamba_track_tensors, dim=0
                ).to(torch.int64)

python/sglang/srt/mem_cache/memory_pool.py (533-547)

medium

The change from creating a tensor from a Python list (torch.tensor(mamba_index, ...)) to stacking existing tensors (torch.stack(mamba_indices)) is a significant improvement. This avoids a CPU-GPU synchronization point and keeps the operations on the GPU, which is beneficial for performance. Similarly, for mamba_ping_pong_track_buffers, stacking tensors is more efficient than converting a list to a tensor.

        mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32)
        self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor
        if self.enable_mamba_extra_buffer:
            ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to(dtype=torch.int32)
            self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = ping_pong_tensor

@YazhiGao YazhiGao changed the title [SGL] sync patch: Remove sync points, prefill cudagraph for DP, KL0 mode [SGL] sync patch: Remove sync points, prefill cudagraph for DP, disable cache reset in mem check Feb 23, 2026
@Qiaolin-Yu
Collaborator

/tag-and-rerun-ci

@Qiaolin-Yu
Collaborator

Could you please fix the lint? Thanks!

- Remove CPU-GPU sync points in logits_processor, schedule_batch, memory_pool
  (pin_memory + non_blocking transfers, GPU-side index construction)
- Add all_extend_in_batch field for prefill CUDA graph with DP attention
- Minor: connector device default, weight update logging fix

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@YazhiGao YazhiGao force-pushed the leon/sync_srt_20260223 branch from e1918f3 to be22b6e on February 25, 2026 at 05:18
@ispobock
Collaborator

@hnyls2002 hnyls2002 merged commit b5a8e41 into sgl-project:main Feb 28, 2026
205 of 216 checks passed
alisonshao pushed a commit that referenced this pull request Feb 28, 2026
…P, disable cache reset in mem check (#19190)"

This reverts commit b5a8e41.
hnyls2002 pushed a commit that referenced this pull request Mar 1, 2026
…P, disable cache reset in mem check (#19190)" (#19581)

Co-authored-by: Alison Shao <alisonshao@mac.lan>
alisonshao pushed a commit that referenced this pull request Mar 1, 2026
Re-applies #19190 (reverted in #19581) but excludes the logits_processor.py
changes that caused KL divergence regression in test_swa_radix_cache_kl.

The logits_processor changes switched sample_indices, input_logprob_indices,
and pruned_lens from synchronous device creation to pin_memory + non_blocking
transfers, and added output_size to repeat_interleave. This removed implicit
GPU sync points that changed numerical behavior, roughly doubling baseline
KL divergence from ~0.001 to ~0.002+ and causing CI flakes.

All other optimizations are preserved:
- Mamba track indices: GPU-only construction without scalar extraction
- Mamba cache zeroing: expand-from-scalar pattern (no CPU-GPU sync)
- Ping-pong buffer: avoid Python-list advanced indexing on device tensors
- all_extend_in_batch field propagation
- connector: device parameter default
- model_runner: empty_cache=False for memory check
alisonshao pushed a commit that referenced this pull request Mar 1, 2026
…le cache reset in mem check (#19190)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
YazhiGao added a commit to YazhiGao/sglang that referenced this pull request Mar 2, 2026
- Vectorize mamba_track_indices via torch.gather instead of per-req scalar extraction
- Vectorize mamba_track_mask via tensor arithmetic instead of Python list comprehension
- Replace Python-list advanced indexing in free_mamba_cache with integer slicing
- Use GPU zero-expand pattern in MambaPool.alloc to avoid implicit CPU-GPU sync
- Keep tensor references in HybridReqToTokenPool.alloc instead of .tolist() roundtrip
- Add all_extend_in_batch field for prefill cudagraph with DP attention
- Default device=None in create_remote_connector
- Avoid unnecessary cache clearing in weight update logging

Split from sgl-project#19190 (reverted in sgl-project#19581): excludes logits_processor.py
changes that caused SWA KL test regression. Mamba decode vectorization
from internal PR.
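The torch.gather vectorization described in the first bullet can be sketched like this, assuming an illustrative two-slot buffer layout rather than the actual sglang tensors:

```python
import torch

# Each request owns a 2-slot ping-pong track buffer; next_idx selects the
# active slot per request. A per-request loop would extract one scalar per
# request (a CPU-GPU sync each time); a single gather does the whole batch
# in one device-side op.
ping_pong = torch.tensor([[10, 11], [20, 21], [30, 31]])  # (num_reqs, 2)
next_idx = torch.tensor([1, 0, 1])                        # slot per request

track = ping_pong.gather(1, next_idx.unsqueeze(1)).squeeze(1)
# track == tensor([11, 20, 31])
```

gather indexes along dim 1 with an index tensor of matching rank, which is why next_idx is unsqueezed to shape (num_reqs, 1) and squeezed back afterward.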
YazhiGao added a commit to YazhiGao/sglang that referenced this pull request Mar 2, 2026
…nsfers

- Use pin_memory + non_blocking for sample_indices and input_logprob_indices
  in _get_pruned_states to avoid CPU-GPU sync
- Use pin_memory + non_blocking for pruned_lens in _expand_metadata_for_logprobs
- Add output_size to repeat_interleave calls to avoid implicit device sync

Note: This was split from sgl-project#19190 (reverted in sgl-project#19581) because
these changes caused SWA KL test regression. Landing separately
to allow independent validation.
YazhiGao added a commit to YazhiGao/sglang that referenced this pull request Mar 2, 2026
- Vectorize mamba_track_indices via torch.gather instead of per-req scalar extraction
- Vectorize mamba_track_mask via tensor arithmetic instead of Python list comprehension
- Replace Python-list advanced indexing in free_mamba_cache with integer slicing
- Use GPU zero-expand pattern in MambaPool.alloc to avoid implicit CPU-GPU sync
- Keep tensor references in HybridReqToTokenPool.alloc instead of .tolist() roundtrip
- Add all_extend_in_batch field for prefill cudagraph with DP attention
- Default device=None in create_remote_connector
- Avoid unnecessary cache clearing in weight update logging

Split from sgl-project#19190 (reverted in sgl-project#19581): excludes logits_processor.py
changes that caused SWA KL test regression. Mamba decode vectorization
from internal PR.
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
…le cache reset in mem check (sgl-project#19190)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
…P, disable cache reset in mem check (sgl-project#19190)" (sgl-project#19581)

Co-authored-by: Alison Shao <alisonshao@mac.lan>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…le cache reset in mem check (sgl-project#19190)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…P, disable cache reset in mem check (sgl-project#19190)" (sgl-project#19581)

Co-authored-by: Alison Shao <alisonshao@mac.lan>

4 participants