
Revert "[SGL] sync patch: Remove sync points, prefill cudagraph for DP, disable cache reset in mem check (#19190)"#19581

Merged
hnyls2002 merged 1 commit into main from revert/sync-patch-19190 on Mar 1, 2026

Conversation

@alisonshao (Collaborator) commented Feb 28, 2026

Motivation

Revert #19190 to verify it is the root cause of the test_swa_radix_cache_kl CI flake on stage-b-test-large-1-gpu.

Bisect results (on H200):

Commit                    | decode KL | prefill KL | Pass?
--------------------------|-----------|------------|------
8c0f2d40b (before #19190) | 0.00086   | 0.00123    | PASS
b5a8e4179 (#19190)        | 0.00272   | 0.00242    | FAIL

The sync patch removes GPU synchronization points in logits_processor.py (pin_memory + non_blocking transfers, output_size in repeat_interleave). This roughly doubles the baseline KL divergence from ~0.001 to ~0.002+, pushing it past the test threshold of 0.002.
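The two tensor-creation patterns at issue can be sketched in PyTorch. The tensor names mirror those mentioned in the patch description, but the values and shapes here are illustrative only, not the actual sglang code:

```python
import torch

# Pattern the revert restores: create the index tensor directly on the
# target device; the host-to-device copy synchronizes implicitly.
sample_indices = torch.tensor([0, 3, 5], dtype=torch.int64, device="cpu")

# Pattern the revert removes: stage in pinned host memory, then copy
# asynchronously, eliminating the implicit sync point. Only meaningful
# with a CUDA device, so it is guarded here to keep the sketch portable.
if torch.cuda.is_available():
    staging = torch.tensor([0, 3, 5], dtype=torch.int64, pin_memory=True)
    sample_indices = staging.to("cuda", non_blocking=True)

# Passing output_size lets repeat_interleave skip reading repeats.sum()
# back to the host, which is another implicit device sync.
repeats = torch.tensor([2, 1, 3])
expanded = torch.repeat_interleave(torch.arange(3), repeats, output_size=6)
```

Removing sync points this way changes kernel ordering and launch timing, which is consistent with the small numerical shifts described above.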

CI failure: https://github.com/sgl-project/sglang/actions/runs/22514914602/job/65230926406

Test plan

  • Verify test_swa_radix_cache_kl passes on stage-b-test-large-1-gpu

…P, disable cache reset in mem check (#19190)"

This reverts commit b5a8e41.
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request reverts a set of performance optimizations introduced in a prior commit (#19190). The primary motivation is to diagnose and resolve a persistent CI test failure (test_swa_radix_cache_kl) that is suspected to be a side effect of the original optimizations. The reversion restores previous behavior related to GPU memory transfers and tensor operations, effectively reintroducing synchronization points and simplifying Mamba cache allocation logic to stabilize the CI pipeline.

Highlights

  • Reverted GPU Synchronization Optimizations: Reverted changes that removed GPU synchronization points, pin_memory, and non_blocking transfers in logits_processor.py and schedule_batch.py to address a CI flake.
  • Removed 'all_extend_in_batch' Property: The all_extend_in_batch property was removed from two_batch_overlap.py, schedule_batch.py, and forward_batch_info.py.
  • Simplified Mamba Cache Management: Streamlined the zero-initialization logic for Mamba cache in memory_pool.py and refactored the handling of Mamba indices and ping-pong track buffers.


Changelog
  • python/sglang/srt/batch_overlap/two_batch_overlap.py
    • Removed 'all_extend_in_batch' from the list of keys to filter.
  • python/sglang/srt/connector/__init__.py
    • Made the 'device' parameter mandatory in the 'create_remote_connector' function.
  • python/sglang/srt/layers/logits_processor.py
    • Reverted tensor creation for 'sample_indices' and 'input_logprob_indices' to remove 'pin_memory=True' and 'non_blocking=True' flags, instead directly specifying the device.
    • Removed the 'output_size' argument from 'torch.repeat_interleave' calls and adjusted 'pruned_lens' tensor creation to remove 'pin_memory=True' and 'non_blocking=True'.
  • python/sglang/srt/managers/schedule_batch.py
    • Removed the 'all_extend_in_batch' attribute from the 'ScheduleBatch' class and its usage in related methods.
    • Simplified the creation of 'mamba_track_indices' and 'mamba_track_mask' by directly assigning the device and removing 'pin_memory=True' and 'non_blocking=True'.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Simplified the zero-filling logic for Mamba cache tensors during allocation.
    • Refactored the handling of 'mamba_indices' and 'mamba_ping_pong_track_buffers' to use lists and then convert to tensors with explicit device assignment.
    • Streamlined the logic for freeing Mamba ping-pong track buffers.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Removed the 'all_extend_in_batch' attribute from the 'ForwardBatch' class and its initialization.
  • python/sglang/srt/model_executor/model_runner.py
    • Removed the 'empty_cache=False' argument from the 'get_available_gpu_memory' function call.

@alisonshao (Collaborator, Author)

/rerun-stage stage-b-test-large-1-gpu

@github-actions (Contributor)

✅ Triggered stage-b-test-large-1-gpu to run independently (skipping dependencies).

@github-actions (Contributor)

🔗 View workflow run

@gemini-code-assist (bot) left a comment


Code Review

This pull request reverts a previous set of changes aimed at improving performance by removing GPU synchronization points. The revert is motivated by CI test flakiness that appeared after the original optimizations. The changes re-introduce synchronization during tensor creation and data transfers by removing pin_memory=True and non_blocking=True options, and also remove an optimization in torch.repeat_interleave by dropping the output_size parameter. The logic for creating certain tensors, such as mamba_track_indices, is simplified, which may result in more CPU-GPU interaction. The field all_extend_in_batch is also removed across several files. This revert seems to be a reasonable step to restore stability.

@alisonshao (Collaborator, Author)

Tested on a local machine: 10/10 runs passed, 0 failures. The revert fixes the test; commit b5a8e41 is confirmed as the root cause.
Bash(ssh -o ConnectTimeout=10 radixark@124.158.103.4 "grep 'avg_kl_div=' ~/bisect_test/kl_test_10x.log" 2>&1)
avg_kl_div=0.0012711324111831731
avg_kl_div=0.0019385124790082355
avg_kl_div=0.0018853885049004882
avg_kl_div=0.0016162643623278288
avg_kl_div=0.0015568186609981084
avg_kl_div=0.0012923602902972216
avg_kl_div=0.0013439864366594325
avg_kl_div=0.0014263578029986405
avg_kl_div=0.001372562079944943
avg_kl_div=0.0012383510878357434
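As a quick sanity check, the ten logged values can be compared against the 0.002 threshold cited in the PR description. This is a hypothetical check, not the actual test harness:

```python
# KL values copied verbatim from the 10-run log above; threshold taken
# from the test description (0.002).
kl_values = [
    0.0012711324111831731, 0.0019385124790082355, 0.0018853885049004882,
    0.0016162643623278288, 0.0015568186609981084, 0.0012923602902972216,
    0.0013439864366594325, 0.0014263578029986405, 0.001372562079944943,
    0.0012383510878357434,
]
THRESHOLD = 0.002
passing = sum(v < THRESHOLD for v in kl_values)
print(f"max={max(kl_values):.6f}, runs passing: {passing}/{len(kl_values)}")
```

All ten runs land below the threshold, though the worst case (~0.00194) sits close to it, consistent with the flake being threshold-sensitive.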

@hnyls2002 hnyls2002 merged commit a45613f into main Mar 1, 2026
103 of 139 checks passed
@hnyls2002 hnyls2002 deleted the revert/sync-patch-19190 branch March 1, 2026 03:46
alisonshao pushed a commit that referenced this pull request Mar 1, 2026
Re-applies #19190 (reverted in #19581) but excludes the logits_processor.py
changes that caused KL divergence regression in test_swa_radix_cache_kl.

The logits_processor changes switched sample_indices, input_logprob_indices,
and pruned_lens from synchronous device creation to pin_memory + non_blocking
transfers, and added output_size to repeat_interleave. This removed implicit
GPU sync points that changed numerical behavior, roughly doubling baseline
KL divergence from ~0.001 to ~0.002+ and causing CI flakes.

All other optimizations are preserved:
- Mamba track indices: GPU-only construction without scalar extraction
- Mamba cache zeroing: expand-from-scalar pattern (no CPU-GPU sync)
- Ping-pong buffer: avoid Python-list advanced indexing on device tensors
- all_extend_in_batch field propagation
- connector: device parameter default
- model_runner: empty_cache=False for memory check
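The "expand-from-scalar" zeroing pattern mentioned in the list above can be sketched as follows. Shapes and names are illustrative stand-ins, not the actual MambaPool layout:

```python
import torch

# Instead of zeroing freed cache slots via per-slot assignments (each of
# which can trigger small kernel launches or host round-trips), broadcast
# a single zero scalar across all freed rows in one indexed assignment.
cache = torch.ones(8, 4)                 # stand-in for a Mamba cache pool
slots = torch.tensor([1, 3, 6])          # slots being (re)allocated
zero = torch.zeros((), dtype=cache.dtype)
cache[slots] = zero.expand(slots.numel(), cache.size(1))
```

The expand produces a view, so no zero-filled buffer of the full slot shape is ever materialized on the host, and no CPU-GPU sync is needed to learn the slot count.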
YazhiGao added a commit to YazhiGao/sglang that referenced this pull request Mar 2, 2026
- Vectorize mamba_track_indices via torch.gather instead of per-req scalar extraction
- Vectorize mamba_track_mask via tensor arithmetic instead of Python list comprehension
- Replace Python-list advanced indexing in free_mamba_cache with integer slicing
- Use GPU zero-expand pattern in MambaPool.alloc to avoid implicit CPU-GPU sync
- Keep tensor references in HybridReqToTokenPool.alloc instead of .tolist() roundtrip
- Add all_extend_in_batch field for prefill cudagraph with DP attention
- Default device=None in create_remote_connector
- Avoid unnecessary cache clearing in weight update logging

Split from sgl-project#19190 (reverted in sgl-project#19581): excludes logits_processor.py
changes that caused SWA KL test regression. Mamba decode vectorization
from internal PR.
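The gather-based vectorization described in the commit above can be sketched like this. The table shape and variable names are hypothetical stand-ins, not the actual sglang data structures:

```python
import torch

# Replacing a Python loop of scalar .item() extractions (each an implicit
# device sync) with a single torch.gather over the tracking table: select
# the per-request rows, then gather one column index per row.
track_table = torch.arange(20).reshape(4, 5)  # stand-in tracking table
req_rows = torch.tensor([0, 2, 3])            # requests in the batch
req_cols = torch.tensor([1, 4, 0])            # tracked slot per request
indices = track_table[req_rows].gather(1, req_cols.unsqueeze(1)).squeeze(1)
```

The whole lookup stays on the device as tensor ops, so no host round-trip occurs regardless of batch size.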
YazhiGao added a commit to YazhiGao/sglang that referenced this pull request Mar 2, 2026
…nsfers

- Use pin_memory + non_blocking for sample_indices and input_logprob_indices
  in _get_pruned_states to avoid CPU-GPU sync
- Use pin_memory + non_blocking for pruned_lens in _expand_metadata_for_logprobs
- Add output_size to repeat_interleave calls to avoid implicit device sync

Note: This was split from sgl-project#19190 (reverted in sgl-project#19581) because
these changes caused SWA KL test regression. Landing separately
to allow independent validation.
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
…P, disable cache reset in mem check (sgl-project#19190)" (sgl-project#19581)

Co-authored-by: Alison Shao <alisonshao@mac.lan>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…P, disable cache reset in mem check (sgl-project#19190)" (sgl-project#19581)

Co-authored-by: Alison Shao <alisonshao@mac.lan>