
[JIT Kernel][Feature] Support JIT custom all reduce (rewrite as v2) #19880

Merged
BBuf merged 20 commits into sgl-project:main from DarkSharpness:jit_custom_all_reduce
Mar 20, 2026
Conversation

@DarkSharpness
Collaborator

@DarkSharpness DarkSharpness commented Mar 4, 2026

Motivation

Modifications

This PR implements a clean version of custom all-reduce that is highly configurable (we can set the number of SMs and the recommended CTA size). We also integrate post-Hopper features such as PDL (Programmatic Dependent Launch) into the custom all-reduce, which improves latency by up to 40% at small batch sizes.

We also implement a push-mode one-shot all-reduce, which is significantly faster than pull mode at small batch sizes.
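For intuition only, here is a host-side sketch of the push/pull distinction, using numpy arrays as stand-ins for IPC-mapped GPU buffers; the layout and names are illustrative, not the PR's actual kernels:

```python
import numpy as np

WORLD_SIZE, N = 4, 8
rng = np.random.default_rng(0)
inputs = [rng.standard_normal(N).astype(np.float32) for _ in range(WORLD_SIZE)]

# Pull-mode one-shot: each rank publishes its input to its own shared buffer,
# then every rank reads (pulls) all peers' buffers and reduces locally.
shared = [inp.copy() for inp in inputs]  # stands in for IPC-mapped memory
pull_out = [sum(shared[r] for r in range(WORLD_SIZE)) for _ in range(WORLD_SIZE)]

# Push-mode one-shot: each rank writes (pushes) its input into a per-source
# slot of every peer's buffer, so each rank reduces purely local slots and
# issues no remote reads on the latency-critical path.
slots = [[np.zeros(N, np.float32) for _ in range(WORLD_SIZE)] for _ in range(WORLD_SIZE)]
for src in range(WORLD_SIZE):
    for dst in range(WORLD_SIZE):
        slots[dst][src][:] = inputs[src]
push_out = [sum(slots[rank]) for rank in range(WORLD_SIZE)]

assert all(np.allclose(a, b) for a, b in zip(pull_out, push_out))
```

Both schemes compute the same sum; the small-size advantage of push mode comes from replacing remote reads with eagerly pushed writes.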

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new JIT-compiled custom all-reduce implementation (version 2) for SGLang, designed to enhance distributed communication performance, particularly for intra-node GPU setups. It provides a flexible and optimized alternative to existing all-reduce methods, leveraging CUDA IPC and custom kernels for efficient data exchange and synchronization.

Highlights

  • JIT Custom All-Reduce Implementation: Introduced a new JIT-compiled custom all-reduce mechanism with dedicated Python bindings (all_reduce.py) and CUDA C++ kernels (custom_all_reduce.cuh) to optimize distributed communication.
  • Distributed Communication Infrastructure: Added new CUDA C++ header files (all_reduce.cuh, common.cuh, ffi.h) that define essential structures and utilities for inter-GPU communication, including IPC memory handling, semaphores, and FFI tensor management.
  • Integration with SGLang's Distributed Runtime: Modified the dispatch_custom_allreduce function to conditionally enable the new JIT-compiled CustomAllReduceV2 via the SGLANG_USE_JIT_ALL_REDUCE environment variable, allowing for flexible adoption.
  • Code Formatting and Utilities: Updated the .clang-format configuration to support the new header file structure and added a div_ceil utility function to utils.cuh.


Changelog
  • python/sglang/jit_kernel/.clang-format
    • Updated include category regex to be more general.
  • python/sglang/jit_kernel/all_reduce.py
    • Added new Python module for JIT custom all-reduce, including CustomAllReduceObj and related functions.
  • python/sglang/jit_kernel/csrc/distributed/custom_all_reduce.cuh
    • Added CUDA C++ kernel implementing one-shot and two-shot all-reduce operations with PDL support.
  • python/sglang/jit_kernel/include/sgl_kernel/distributed/all_reduce.cuh
    • Added C++ header defining CustomAllReduceBase for managing distributed all-reduce state and IPC memory handles.
  • python/sglang/jit_kernel/include/sgl_kernel/distributed/common.cuh
    • Added C++ header for distributed communication common utilities, including Semaphore and Controller for inter-GPU synchronization.
  • python/sglang/jit_kernel/include/sgl_kernel/ffi.h
    • Added C++ header for FFI utilities, including empty, empty_like, and from_blob for TVM tensor creation.
  • python/sglang/jit_kernel/include/sgl_kernel/utils.cuh
    • Added a div_ceil utility function for ceiling division (a Python equivalent is sketched after this changelog).
  • python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
    • Modified the dispatch logic to optionally use the new JIT-compiled CustomAllReduceV2 based on an environment variable.
  • python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py
    • Implemented a new Python class CustomAllReduceV2 to interface with the JIT-compiled custom all-reduce, supporting CUDA graph capture.
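For reference, div_ceil is ordinary ceiling division; a one-line Python equivalent of the utility mentioned in the changelog above:

```python
def div_ceil(numerator: int, denominator: int) -> int:
    """Smallest integer >= numerator / denominator (for positive ints)."""
    return (numerator + denominator - 1) // denominator

assert div_ceil(10, 4) == 3  # e.g. CTAs needed to cover 10 items at 4 per CTA
assert div_ceil(8, 4) == 2   # exact division needs no extra block
```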
Activity
  • The pull request is marked as 'Work In Progress' (WIP), indicating that development is ongoing.
  • The description includes standard contribution guidelines and a checklist for formatting, unit tests, documentation, and benchmarking, suggesting the author is following established development practices.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a JIT-compiled custom all-reduce implementation (v2) as an opt-in feature, including new Python bindings, CUDA kernels, and host-side control logic. It supports both one-shot and two-shot algorithms, integrates with CUDA graph capturing, and is enabled via the SGLANG_USE_JIT_ALL_REDUCE environment variable. However, a significant security vulnerability was identified: memory offsets used for CUDA graph registration are truncated from 64-bit to 32-bit integers during inter-process communication. This could lead to incorrect memory access, memory corruption, or information disclosure on the GPU. It is strongly recommended to use 64-bit integers for all memory-related offsets. Additionally, the code review focused on improving clarity, maintainability, and configuration flexibility, with suggestions for clarifying comments, removing unused code, and making hardcoded parameters configurable.
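For intuition on the flagged truncation, a standalone sketch (the struct packing here is hypothetical, not the PR's actual IPC encoding):

```python
import struct

# A device-memory offset beyond 4 GiB no longer fits in 32 bits.
offset = 5 * (1 << 30)  # 5 GiB into a large allocation

# Keeping only the low 32 bits (what a 32-bit field would retain) loses data.
truncated = struct.unpack("<i", struct.pack("<q", offset)[:4])[0]
print(hex(offset), "->", hex(truncated))  # 0x140000000 -> 0x40000000

# A 64-bit field ("<q") round-trips the offset correctly.
assert struct.unpack("<q", struct.pack("<q", offset))[0] == offset
```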

Comment thread python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py Outdated
Comment thread python/sglang/jit_kernel/all_reduce.py Outdated
Comment thread python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py Outdated
Comment thread python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py Outdated
Comment thread python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py Outdated
@DarkSharpness DarkSharpness changed the title [WIP][Feature] Support JIT custom all reduce [WIP][JIT Kernel][Feature] Support JIT custom all reduce Mar 5, 2026
@DarkSharpness DarkSharpness marked this pull request as ready for review March 7, 2026 03:49
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@DarkSharpness DarkSharpness changed the title [WIP][JIT Kernel][Feature] Support JIT custom all reduce [JIT Kernel][Feature] Support JIT custom all reduce (rewrite as v2) Mar 7, 2026
@DarkSharpness
Collaborator Author

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Mar 7, 2026
@DarkSharpness DarkSharpness force-pushed the jit_custom_all_reduce branch 2 times, most recently from 3256786 to 00be07d Compare March 11, 2026 08:31
@DarkSharpness
Collaborator Author

cc @BBuf @yuan-luo @HydraQYH. For now we implement a push-mode one-shot all-reduce and the normal pull-mode one-/two-shot all-reduce, which can be significantly faster than the AOT custom all-reduce.

Currently SGLANG_USE_JIT_ALL_REDUCE defaults to false (so default behavior is unchanged), and we plan to turn it on by default later once we finish a few more optimizations (e.g., norm fusion).
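Assuming the flag behaves like SGLang's other boolean environment variables, opting in would look roughly like this (a hypothetical sketch; the exact accepted values are whatever get_bool_env_var accepts):

```python
import os

# Hypothetical opt-in sketch: set the flag before SGLang initializes its
# distributed communicators, which otherwise keep the AOT path as default.
os.environ["SGLANG_USE_JIT_ALL_REDUCE"] = "1"

# Mirrors a typical boolean env-var check (SGLang's get_bool_env_var may differ).
use_jit = os.environ.get("SGLANG_USE_JIT_ALL_REDUCE", "false").lower() in ("1", "true")
print("JIT custom all-reduce enabled:", use_jit)
```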

@DarkSharpness
Collaborator Author

Some performance results for TP=4 on H200/B200

**H200**

| Size | NCCL | AOT | JIT | FI | AOT/NCCL | JIT/NCCL | FI/NCCL |
|------|------|-----|-----|----|----------|----------|---------|
| 4K   | 21.5 | 4.4 | 2.4 | 2.8 | 4.92x | 9.04x | 7.74x |
| 16K  | 22.6 | 4.6 | 2.5 | 2.9 | 4.95x | 8.97x | 7.85x |
| 64K  | 23.0 | 5.0 | 2.9 | 3.3 | 4.58x | 7.83x | 6.88x |
| 128K | 23.5 | 5.5 | 3.6 | 3.9 | 4.27x | 6.52x | 5.96x |
| 192K | 23.7 | 6.2 | 4.2 | 4.5 | 3.84x | 5.61x | 5.22x |
| 256K | 23.6 | 6.7 | 4.8 | 5.1 | 3.53x | 4.87x | 4.60x |
| 384K | 23.9 | 8.3 | 6.2 | 6.4 | 2.90x | 3.85x | 3.72x |
| 512K | 26.5 | 10.2 | 6.9 | 7.6 | 2.59x | 3.84x | 3.50x |
| 640K | 26.8 | 10.6 | 7.3 | 9.3 | 2.53x | 3.67x | 2.89x |
| 768K | 27.3 | 11.0 | 8.1 | 10.3 | 2.47x | 3.37x | 2.66x |
| 896K | 29.3 | 11.3 | 8.9 | 11.4 | 2.59x | 3.29x | 2.58x |
| 1M   | 29.9 | 11.7 | 9.6 | 12.4 | 2.56x | 3.12x | 2.41x |
| 2M   | 39.3 | 17.5 | 14.4 | 23.4 | 2.25x | 2.74x | 1.68x |
| 3M   | 49.8 | 23.6 | 19.4 | 35.1 | 2.11x | 2.57x | 1.42x |
| 4M   | 52.4 | 29.2 | 23.8 | 47.2 | 1.80x | 2.20x | 1.11x |
| 8M   | 83.7 | 52.4 | 43.7 | 96.9 | 1.60x | 1.92x | 0.86x |
| 16M  | 119.3 | 95.5 | 83.7 | 112.3 | 1.25x | 1.43x | 1.06x |
| 32M  | 193.8 | 182.2 | 161.9 | 234.5 | 1.06x | 1.20x | 0.83x |

**B200**

| Size | NCCL | AOT | JIT | FI | AOT/NCCL | JIT/NCCL | FI/NCCL |
|------|------|-----|-----|----|----------|----------|---------|
| 4K   | 23.5 | 6.5 | 3.3 | 3.9 | 3.61x | 7.18x | 6.05x |
| 16K  | 24.3 | 6.7 | 3.4 | 4.0 | 3.61x | 7.23x | 6.05x |
| 64K  | 26.2 | 7.0 | 3.7 | 4.1 | 3.76x | 7.10x | 6.45x |
| 128K | 28.0 | 7.2 | 3.9 | 4.3 | 3.86x | 7.12x | 6.50x |
| 192K | 28.7 | 7.5 | 4.1 | 4.5 | 3.83x | 6.93x | 6.36x |
| 256K | 28.4 | 7.9 | 4.4 | 4.9 | 3.59x | 6.42x | 5.83x |
| 384K | 29.0 | 9.9 | 5.2 | 5.5 | 2.92x | 5.56x | 5.27x |
| 512K | 28.7 | 17.7 | 5.8 | 6.1 | 1.62x | 4.93x | 4.70x |
| 640K | 30.5 | 17.8 | 6.7 | 7.1 | 1.72x | 4.55x | 4.29x |
| 768K | 30.9 | 17.9 | 7.5 | 7.7 | 1.72x | 4.14x | 3.99x |
| 896K | 30.7 | 18.0 | 7.9 | 8.2 | 1.70x | 3.91x | 3.73x |
| 1M   | 31.1 | 18.2 | 8.5 | 8.8 | 1.71x | 3.67x | 3.53x |
| 2M   | 36.2 | 26.6 | 14.0 | 14.8 | 1.36x | 2.59x | 2.45x |
| 3M   | 41.8 | 35.2 | 18.3 | 21.3 | 1.19x | 2.29x | 1.97x |
| 4M   | 48.1 | 43.3 | 21.9 | 27.4 | 1.11x | 2.20x | 1.76x |
| 8M   | 58.0 | 74.0 | 31.3 | 56.6 | 0.78x | 1.86x | 1.02x |
| 16M  | 88.8 | 131.0 | 49.8 | 78.4 | 0.68x | 1.78x | 1.13x |
| 32M  | 131.3 | 252.3 | 87.1 | 143.4 | 0.52x | 1.51x | 0.92x |
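Reading the tables: the ratio columns appear to be NCCL's time divided by each implementation's, so values above 1.00x mean faster than NCCL. A quick check against the displayed (rounded) numbers:

```python
# H200, 4K row: NCCL vs. JIT latency as displayed in the table.
nccl, jit = 21.5, 2.4
print(f"{nccl / jit:.2f}x")  # ~8.96x; the reported 9.04x presumably uses unrounded timings
```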

Comment thread python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py Outdated
Comment thread python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py Outdated
Comment thread python/sglang/jit_kernel/tests/test_custom_all_reduce.py Outdated
```yaml
- Regex: '^<sgl_kernel/.*\.h>$'
  Priority: 0
- Regex: '^<sgl_kernel/impl/.*>$'
- Regex: '^<sgl_kernel/.*/.*>$'
```
Collaborator


Why do we need to update this regex?

Collaborator Author


Because there are more secondary headers in the JIT kernel now. In this PR we introduce <sgl_kernel/distributed/xxx.cuh>. This rule covers all of them without breaking existing code.
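A quick standalone check that the broadened rule matches the new secondary headers (plain Python re here; clang-format's own regex engine may differ in minor details):

```python
import re

pattern = re.compile(r"^<sgl_kernel/.*/.*>$")
assert pattern.match("<sgl_kernel/distributed/all_reduce.cuh>")  # new in this PR
assert pattern.match("<sgl_kernel/impl/anything.cuh>")           # old rule's headers still match
assert not pattern.match("<sgl_kernel/utils.cuh>")               # top-level headers hit other rules
```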

Comment thread python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py
Comment thread python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py Outdated
Comment thread python/sglang/jit_kernel/tests/test_custom_all_reduce.py Outdated
```python
graph = torch.cuda.CUDAGraph()
graph_inp = torch.zeros((TEST_LAYERS, size), dtype=dtype, device=device)
out_jits = []
with comm.capture():
```
Collaborator


How do you plan to handle CUDA graph compatibility for the pull-based custom all-reduce path in real LLM runs? It seems this path depends on the extra comm.capture() address-registration flow, so I’m not sure what the intended graph capture / recapture lifecycle is here.

Collaborator Author


The AOT custom all-reduce already uses a similar graph-buffer registration method; we just follow its design:

```python
@contextmanager
def capture(self):
    """
    The main responsibility of this context manager is the
    `register_graph_buffers` call at the end of the context.
    It records all the buffer addresses used in the CUDA graph.
    """
    try:
        self._IS_CAPTURING = True
        yield
    finally:
        self._IS_CAPTURING = False
        if not self.disabled:
            self.register_graph_buffers()
```
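To make the lifecycle concrete, here is a self-contained toy mirroring that flow; _FakeComm and its members are invented for illustration (the real communicator records CUDA graph buffer addresses and registers them across ranks):

```python
from contextlib import contextmanager

class _FakeComm:
    def __init__(self):
        self._IS_CAPTURING = False
        self.disabled = False
        self._recorded = []

    @contextmanager
    def capture(self):
        # Same shape as the snippet above: flip the flag, then register
        # everything that was recorded once the context exits.
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def all_reduce(self, buffer_addr: int) -> None:
        if self._IS_CAPTURING:
            self._recorded.append(buffer_addr)  # remember addresses seen during capture

    def register_graph_buffers(self) -> None:
        print(f"registering {len(self._recorded)} graph buffer addresses")

comm = _FakeComm()
with comm.capture():  # CUDA graph capture would happen inside this context
    comm.all_reduce(0x1000)
    comm.all_reduce(0x2000)
# prints: registering 2 graph buffer addresses
```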

@DarkSharpness
Collaborator Author

DarkSharpness commented Mar 14, 2026

> I'd like to know how peak memory usage differs between push mode and pull mode. Could the benchmark include this data?

The buffer memory usage is 2 * world_size * push_buffer_size + pull_buffer_size. By default, it will not exceed 32 MB. Other control buffers may take up to around 12 MB. This should be small enough (FYI, the flashinfer workspace can easily consume 128 MB).
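A quick back-of-the-envelope check of that formula (the buffer sizes below are illustrative assumptions, not the actual defaults):

```python
MB = 1 << 20

def allreduce_buffer_bytes(world_size: int, push_buffer: int, pull_buffer: int) -> int:
    # Formula from the comment above: 2 * world_size * push + pull.
    return 2 * world_size * push_buffer + pull_buffer

print(allreduce_buffer_bytes(4, 1 * MB, 8 * MB) // MB, "MB")  # 16 MB at TP=4
print(allreduce_buffer_bytes(8, 1 * MB, 8 * MB) // MB, "MB")  # 24 MB at TP=8
```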

@yuan-luo
Collaborator

/rerun-failed-ci

@yuan-luo
Collaborator

yuan-luo commented Mar 18, 2026

Could you paste the benchmark results for this new CA kernel? That would be great.

@DarkSharpness
Collaborator Author

Some performance results for TP=4 on H200/B200 (same H200/B200 tables as in my earlier comment above).

benchmark result here @yuan-luo

@yuan-luo
Collaborator

> Some performance results for TP=4 on H200/B200
>
> benchmark result here @yuan-luo

@DarkSharpness Awesome benchmark results. Could we put them in the PR description's Benchmarking and Profiling section?

"""
# HARDCODED: opt-in flag for v2 JIT all-reduce.
# Set SGLANG_USE_JIT_ALL_REDUCE=1 to enable.
if _is_cuda and get_bool_env_var("SGLANG_USE_JIT_ALL_REDUCE", default="false"):
Collaborator


If it is stable and outperforms the other all-reduce implementations, can we set it to default true?

Collaborator Author


I will enable it in another PR. This PR is already large and involves some cleanup of the parallel state; we should ensure the correctness of that part first.

```python
# NOTE: This result is based on benchmarks on H200 GPUs
THRESHOLD_2_SHOT_MAP = {
    2: ModeConfig(2 * MB, INF),
    3: ModeConfig(512 * MB, 512 * KB),
```
Collaborator


Is this 512 * MB expected? It seems to differ too much from the other thresholds.

Collaborator Author


It should be KB.
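So the corrected entry would presumably read as below (a sketch of the acknowledged typo fix; ModeConfig's field names here are assumptions made for the sketch):

```python
from collections import namedtuple

KB, MB = 1 << 10, 1 << 20
INF = float("inf")
ModeConfig = namedtuple("ModeConfig", ["push_threshold", "pull_threshold"])  # names assumed

THRESHOLD_2_SHOT_MAP = {
    2: ModeConfig(2 * MB, INF),
    3: ModeConfig(512 * KB, 512 * KB),  # was "512 * MB", a typo for KB
}
```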

@BBuf BBuf merged commit 2dd9196 into sgl-project:main Mar 20, 2026
141 of 166 checks passed
@DarkSharpness DarkSharpness deleted the jit_custom_all_reduce branch March 20, 2026 10:25
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…gl-project#19880)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
…gl-project#19880)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
…gl-project#19880)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
…gl-project#19880)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…gl-project#19880)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
@hnyls2002 hnyls2002 mentioned this pull request Apr 29, 2026
