Skip to content

Remove sync points in mamba cache + prefill cudagraph plumbing for DP#19639

Merged
ispobock merged 1 commit intosgl-project:mainfrom
YazhiGao:sync-safe
Mar 2, 2026
Merged

Remove sync points in mamba cache + prefill cudagraph plumbing for DP#19639
ispobock merged 1 commit intosgl-project:mainfrom
YazhiGao:sync-safe

Conversation

@YazhiGao
Copy link
Copy Markdown
Contributor

@YazhiGao YazhiGao commented Mar 2, 2026

Summary

  • Vectorize mamba_track_indices construction via torch.gather instead of per-request scalar extraction in prepare_for_decode
  • Vectorize mamba_track_mask via tensor arithmetic instead of Python list comprehension
  • Replace Python-list advanced indexing in free_mamba_cache with integer slicing
  • Use GPU zero-expand pattern in MambaPool.alloc to avoid implicit CPU-GPU sync on cache clear
  • Keep tensor references in HybridReqToTokenPool.alloc instead of .tolist() roundtrip
  • Add all_extend_in_batch field plumbing for prefill cudagraph with DP attention
  • Default device=None in create_remote_connector
  • Avoid unnecessary torch.cuda.empty_cache() in weight update diagnostic logging

Split from #19190 (reverted in #19581): this PR excludes the logits_processor.py changes that caused the SWA KL test regression.

Test plan

  • Existing mamba tests pass
  • SWA KL test (test_swa_radix_cache_kl) passes (this PR does not touch logits_processor.py)
  • No regression in decode throughput

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on significant performance enhancements and synchronization reductions within the Mamba cache management system. By vectorizing key operations and optimizing memory allocation strategies, it aims to improve sampling and end-to-end MFU. Additionally, it introduces necessary plumbing for prefill cudagraphs with data parallel attention, streamlining the batch processing logic. The changes also include minor code cleanups to prevent unnecessary overhead.

Highlights

  • Mamba Cache Optimization: Vectorized mamba_track_indices construction using torch.gather and mamba_track_mask via tensor arithmetic, replacing per-request scalar extraction and Python list comprehensions to improve performance.
  • Reduced CPU-GPU Sync: Implemented a GPU zero-expand pattern in MambaPool.alloc to clear cache without implicit CPU-GPU synchronization, and retained tensor references in HybridReqToTokenPool.alloc to avoid tolist() roundtrips.
  • Prefill Cudagraph Plumbing: Added all_extend_in_batch field plumbing to support prefill cudagraphs with Data Parallel (DP) attention.
  • Code Cleanups and Minor Improvements: Replaced Python-list advanced indexing with integer slicing in free_mamba_cache, defaulted device=None in create_remote_connector, and prevented unnecessary torch.cuda.empty_cache() calls during weight update diagnostic logging.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/batch_overlap/two_batch_overlap.py
    • Added 'all_extend_in_batch' to the list of keys filtered for batch processing.
  • python/sglang/srt/connector/init.py
    • Modified create_remote_connector to set device parameter default to None.
  • python/sglang/srt/managers/schedule_batch.py
    • Added all_extend_in_batch attribute to ScheduleBatch class.
    • Vectorized the creation of mamba_track_indices using torch.gather for efficiency.
    • Vectorized the creation of mamba_track_mask using tensor arithmetic and asynchronous host-to-device transfer.
    • Passed all_extend_in_batch to get_model_worker_batch and copy methods.
    • Added all_extend_in_batch to the ModelWorkerBatch class definition.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Updated MambaPool.alloc to use GPU zero-expansion for clearing cache, avoiding CPU-GPU synchronization.
    • Modified HybridReqToTokenPool.alloc to store tensor references directly instead of converting to and from Python lists.
    • Replaced Python list-based advanced indexing with integer slicing for freeing Mamba cache entries in free_mamba_cache.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Added all_extend_in_batch attribute to ForwardBatch class.
    • Included all_extend_in_batch when initializing new ForwardBatch instances.
  • python/sglang/srt/model_executor/model_runner.py
    • Modified update_weights_from_disk to prevent torch.cuda.empty_cache() from being called during memory logging, avoiding unnecessary cache clearing.
Activity
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several performance optimizations, primarily for Mamba models, by removing synchronization points, vectorizing operations, and improving memory management. The changes, such as using torch.gather for vectorized indexing and avoiding GPU-CPU roundtrips, are well-implemented and should lead to noticeable performance gains. The plumbing for all_extend_in_batch is also correctly integrated. Overall, this is a solid set of improvements.

Comment on lines 347 to +357
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, select_index] = 0
self.mamba_cache.temporal[:, select_index] = 0
t = self.mamba_cache.conv[i]
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
t = self.mamba_cache.temporal
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for creating a zero tensor and assigning it to a slice of the cache is duplicated for self.mamba_cache.conv and self.mamba_cache.temporal. You can refactor this into a single loop to improve code clarity and reduce duplication.

Suggested change
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, select_index] = 0
self.mamba_cache.temporal[:, select_index] = 0
t = self.mamba_cache.conv[i]
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
t = self.mamba_cache.temporal
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
for t in self.mamba_cache.conv + [self.mamba_cache.temporal]:
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z

- Vectorize mamba_track_indices via torch.gather instead of per-req scalar extraction
- Vectorize mamba_track_mask via tensor arithmetic instead of Python list comprehension
- Replace Python-list advanced indexing in free_mamba_cache with integer slicing
- Use GPU zero-expand pattern in MambaPool.alloc to avoid implicit CPU-GPU sync
- Keep tensor references in HybridReqToTokenPool.alloc instead of .tolist() roundtrip
- Add all_extend_in_batch field for prefill cudagraph with DP attention
- Default device=None in create_remote_connector
- Avoid unnecessary cache clearing in weight update logging

Split from sgl-project#19190 (reverted in sgl-project#19581): excludes logits_processor.py
changes that caused SWA KL test regression. Mamba decode vectorization
from internal PR.
@YazhiGao
Copy link
Copy Markdown
Contributor Author

YazhiGao commented Mar 2, 2026

/tag-and-rerun-ci

1 similar comment
@ispobock
Copy link
Copy Markdown
Collaborator

ispobock commented Mar 2, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Mar 2, 2026
@hnyls2002
Copy link
Copy Markdown
Collaborator

/rerun-failed-ci

@ispobock
Copy link
Copy Markdown
Collaborator

ispobock commented Mar 2, 2026

@ispobock ispobock merged commit 07ef5f7 into sgl-project:main Mar 2, 2026
295 of 332 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants