[SGL] sync patch: Remove sync points, prefill cudagraph for DP, disable cache reset in mem check #19190
Conversation
Summary of Changes

Hello @YazhiGao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request optimizes the SGLang runtime by removing unnecessary synchronization points, adding support for CUDA graph prefill, and implementing a deterministic KL0 mode. These changes aim to improve performance and determinism in distributed training scenarios.

Highlights
Activity
force-pushed from bac319b to 8c1c97c
Code Review
The pull request focuses on optimizing GPU-CPU synchronization and enabling CUDA graph prefill for DP attention, along with introducing a KL0 mode for deterministic reductions. The changes involve modifying tensor creation and data transfer methods to reduce sync points, adding fields to support CUDA graph capture, and implementing a new reduction path for KL0 mode. Overall, the changes aim to improve performance and determinism.
python/sglang/srt/connector/__init__.py (25)
The device parameter is given a default value of None, but its usage within the function (e.g., RedisConnector(url)) does not seem to utilize this parameter. If device is not used, it should be removed to maintain clean API design and avoid confusion.
def create_remote_connector(url, **kwargs) -> BaseConnector:
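The reviewer's suggestion can be sketched as follows. `BaseConnector` and `RedisConnector` mirror the names quoted in the review, but the bodies here are illustrative stubs, not the real sglang implementation:

```python
class BaseConnector:
    def __init__(self, url: str):
        self.url = url


class RedisConnector(BaseConnector):
    pass


def create_remote_connector(url: str, **kwargs) -> BaseConnector:
    # The unused `device=None` parameter is dropped from the signature;
    # backend-specific options can still travel through **kwargs.
    if url.startswith("redis://"):
        return RedisConnector(url)
    raise ValueError(f"unsupported connector url: {url}")
```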
python/sglang/srt/distributed/parallel_state.py (551)
The import statement for is_kl0_mode_enabled is placed inside the all_reduce method. This can lead to repeated imports every time the method is called, potentially impacting performance. It's better to move imports to the top of the file to ensure they are executed only once.
from sglang.private.utils.torch_utils import is_kl0_mode_enabled
if is_kl0_mode_enabled():

python/sglang/srt/distributed/parallel_state.py (688)
Similar to the previous comment, the import statement for is_kl0_mode_enabled is placed inside the _reduce_scatter_tensor method. This can lead to repeated imports every time the method is called, potentially impacting performance. It's better to move imports to the top of the file to ensure they are executed only once.
from sglang.private.utils.torch_utils import is_kl0_mode_enabled
if is_kl0_mode_enabled():

python/sglang/srt/managers/schedule_batch.py (2018-2031)
The logic for constructing self.mamba_track_indices is complex and involves conditional slicing within a list comprehension, followed by torch.cat. This could be simplified by first preparing the list of tensors and then concatenating them, or by using a more direct tensor operation if possible. The current approach is hard to read and debug.
if len(self.reqs) == 0:
self.mamba_track_indices = torch.empty(
(0,), dtype=torch.int64, device=self.device
)
else:
mamba_track_tensors = []
for req in self.reqs:
if req.mamba_next_track_idx == 1:
mamba_track_tensors.append(req.mamba_ping_pong_track_buffer[1:])
else:
mamba_track_tensors.append(req.mamba_ping_pong_track_buffer[:1])
self.mamba_track_indices = torch.cat(mamba_track_tensors, dim=0).to(torch.int64)

python/sglang/srt/mem_cache/memory_pool.py (533-547)
The change from creating a tensor from a Python list (torch.tensor(mamba_index, ...)) to stacking existing tensors (torch.stack(mamba_indices)) is a significant improvement. This avoids a CPU-GPU synchronization point and keeps the operations on the GPU, which is beneficial for performance. Similarly, for mamba_ping_pong_track_buffers, stacking tensors is more efficient than converting a list to a tensor.
mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32)
self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor
if self.enable_mamba_extra_buffer:
ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to(dtype=torch.int32)
self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = ping_pong_tensor
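The difference the reviewer praises can be demonstrated with a small CPU-safe sketch (names are illustrative; on a CPU-only machine the two paths cost the same, but on CUDA the list path forces a device-to-host sync per element):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-request index tensors already resident on `device`.
mamba_indices = [torch.tensor(i, device=device) for i in (3, 7, 11)]

# Sync-prone: .item() copies each scalar back to the host before the
# new tensor is built and shipped to the device again.
from_list = torch.tensor([t.item() for t in mamba_indices],
                         dtype=torch.int32, device=device)

# Sync-free: stack the device tensors directly, then cast.
stacked = torch.stack(mamba_indices).to(dtype=torch.int32)

assert torch.equal(from_list, stacked)
```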
/tag-and-rerun-ci
Could you please fix the lint? Thanks!
- Remove CPU-GPU sync points in logits_processor, schedule_batch, memory_pool (pin_memory + non_blocking transfers, GPU-side index construction)
- Add all_extend_in_batch field for prefill CUDA graph with DP attention
- Minor: connector device default, weight update logging fix

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
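The pin_memory + non_blocking pattern named in the commit message can be sketched as below; the helper name is made up for illustration, and the code degrades gracefully on CPU-only machines:

```python
import torch

def to_device_async(values, device, dtype=torch.int64):
    """Hedged sketch of the pin_memory + non_blocking transfer pattern.

    Building the tensor in pinned host memory and copying it with
    non_blocking=True lets the host-to-device transfer overlap with
    other GPU work instead of serializing behind an implicit sync.
    """
    t = torch.tensor(values, dtype=dtype)
    if torch.cuda.is_available() and device != "cpu":
        t = t.pin_memory()
        return t.to(device, non_blocking=True)
    return t.to(device)  # CPU fallback: a plain copy
```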
force-pushed from e1918f3 to be22b6e
These two CIs cannot pass:
Need to double-check whether they are related to this change.
Re-applies #19190 (reverted in #19581) but excludes the logits_processor.py changes that caused a KL divergence regression in test_swa_radix_cache_kl. The logits_processor changes switched sample_indices, input_logprob_indices, and pruned_lens from synchronous device creation to pin_memory + non_blocking transfers, and added output_size to repeat_interleave. This removed implicit GPU sync points that changed numerical behavior, roughly doubling baseline KL divergence from ~0.001 to ~0.002+ and causing CI flakes. All other optimizations are preserved:

- Mamba track indices: GPU-only construction without scalar extraction
- Mamba cache zeroing: expand-from-scalar pattern (no CPU-GPU sync)
- Ping-pong buffer: avoid Python-list advanced indexing on device tensors
- all_extend_in_batch field propagation
- connector: device parameter default
- model_runner: empty_cache=False for memory check
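The "expand-from-scalar" zeroing pattern mentioned above can be sketched roughly as follows; the function name and buffer layout are assumptions, since the actual MambaPool.alloc code is not shown in this thread:

```python
import torch

def zero_rows(buf: torch.Tensor, idx: torch.Tensor) -> None:
    # Expand a 0-dim zero tensor that already lives on buf.device across
    # the selected rows. No Python scalar crosses the CPU-GPU boundary,
    # so no implicit sync is introduced on CUDA.
    zero = torch.zeros((), dtype=buf.dtype, device=buf.device)
    buf[idx] = zero.expand(idx.numel(), buf.size(1))
```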
…le cache reset in mem check (#19190) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: ispobock <ispobaoke@gmail.com>
- Vectorize mamba_track_indices via torch.gather instead of per-req scalar extraction
- Vectorize mamba_track_mask via tensor arithmetic instead of a Python list comprehension
- Replace Python-list advanced indexing in free_mamba_cache with integer slicing
- Use GPU zero-expand pattern in MambaPool.alloc to avoid implicit CPU-GPU sync
- Keep tensor references in HybridReqToTokenPool.alloc instead of a .tolist() roundtrip
- Add all_extend_in_batch field for prefill cudagraph with DP attention
- Default device=None in create_remote_connector
- Avoid unnecessary cache clearing in weight update logging

Split from sgl-project#19190 (reverted in sgl-project#19581): excludes logits_processor.py changes that caused the SWA KL test regression. Mamba decode vectorization from internal PR.
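The torch.gather vectorization of the track indices can be sketched like this; the (n_reqs, 2) ping-pong layout is inferred from the review snippet above and may not match the real sglang data structures exactly:

```python
import torch

def build_track_indices(ping_pong_buffers, next_track_idx):
    # Replace the per-request Python branch (buffer[1:] vs buffer[:1])
    # with one stack + gather that stays on the device.
    if not ping_pong_buffers:
        return torch.empty((0,), dtype=torch.int64)
    stacked = torch.stack(ping_pong_buffers)          # (n_reqs, 2)
    idx = next_track_idx.view(-1, 1).to(torch.int64)  # (n_reqs, 1)
    return stacked.gather(1, idx).view(-1).to(torch.int64)
```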
…nsfers

- Use pin_memory + non_blocking for sample_indices and input_logprob_indices in _get_pruned_states to avoid CPU-GPU sync
- Use pin_memory + non_blocking for pruned_lens in _expand_metadata_for_logprobs
- Add output_size to repeat_interleave calls to avoid implicit device sync

Note: This was split from sgl-project#19190 (reverted in sgl-project#19581) because these changes caused the SWA KL test regression. Landing separately to allow independent validation.
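The output_size argument mentioned above is a real torch.repeat_interleave parameter: when repeats is a CUDA tensor, omitting it forces PyTorch to read repeats.sum() back to the host to size the output, which is an implicit sync. A minimal CPU-runnable illustration:

```python
import torch

x = torch.tensor([10, 20, 30])
repeats = torch.tensor([1, 2, 3])

# Without output_size, the output length must be computed from
# repeats.sum() -- a device-to-host readback when repeats is on CUDA.
implicit = torch.repeat_interleave(x, repeats)

# With output_size supplied from host-side metadata, that readback is
# skipped; the result is identical.
explicit = torch.repeat_interleave(x, repeats, output_size=6)

assert torch.equal(implicit, explicit)
```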
…le cache reset in mem check (sgl-project#19190) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: ispobock <ispobaoke@gmail.com>
…P, disable cache reset in mem check (sgl-project#19190)" (sgl-project#19581) Co-authored-by: Alison Shao <alisonshao@mac.lan>
Summary
- Use pin_memory=True + .to(device, non_blocking=True) instead of direct device tensor construction, use output_size in repeat_interleave to avoid implicit sync, build mamba track indices on GPU without scalar extraction, use GPU zero-expand instead of index_fill_
- Add all_extend_in_batch field in schedule_batch, forward_batch_info, and ModelWorkerBatch to support prefill CUDA graph capture with DP attention
- Minor: connector device=None default, empty_cache=False in weight update logging

Test plan