[None][feat] Add NCCL device kernels for AR+RMS fusion #7910
base: main
Conversation
📝 Walkthrough
Adds a new NCCL_DEVICE all-reduce strategy and device-side fusion module (nccl_device): build targets, CUDA kernels, multimem/vector helpers, launch-config factory, runtime dispatch and allocator support, Python/enum plumbing, benchmarks, and tests.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant PY as Python API
participant ME as ModelEngine
participant OP as AllreduceOp (C++)
participant UB as NCCLUserBufferAllocator
participant KD as nccl_device::LaunchConfig
participant KRN as nccl_device Kernel
participant NCCL as NCCL (host+device)
PY->>ME: request AllReduce with strategy="NCCL_DEVICE"
ME->>OP: execute AllReduce (inputs, fusion attrs)
OP->>UB: getNCCLDevComm(numBarriers)
UB->>NCCL: resolve/create device communicator
UB-->>OP: ncclDevComm
OP->>UB: getCachedNCCLDeviceLaunchConfig(dtype, dims, flags)
UB-->>OP: LaunchConfig (KD)
OP->>KRN: KD.launchRMSNorm(..., devComm, stream)
KRN->>NCCL: device-side allreduce (multimem ld/st)
KRN-->>OP: outputs written
OP-->>ME: return tensors
ME-->>PY: result
sequenceDiagram
autonumber
participant OP as AllreduceOp (C++)
participant SYM as UB Symmetric Buffers
participant F as Fallback Path
OP->>SYM: Verify symmetric UB buffer
alt buffer missing
OP->>SYM: Create symmetric input and copy data
end
alt device fusion unsupported or invalid
OP->>F: fallbackRunSubsequentOps(...)
end
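The flow in these diagrams can be summarized in a short host-side sketch. This is a hypothetical illustration only: apart from fallbackRunSubsequentOps, which appears in the diagram above, every type and helper name here is an assumed placeholder, not the actual TRT-LLM API.

```cpp
// Hypothetical sketch of the dispatch/fallback decision shown in the diagrams.
// All names except fallbackRunSubsequentOps are illustrative placeholders.
#include <cstddef>

struct Tensor { };       // stand-in for the real tensor type
struct FusionParams { }; // stand-in for the fusion attributes

bool isSymmetricUserBuffer(Tensor const& t);
Tensor copyIntoSymmetricBuffer(Tensor const& t);
bool deviceFusionSupported(FusionParams const& fusion, std::size_t messageBytes);
Tensor launchDeviceFusedAllReduce(Tensor const& t, FusionParams const& fusion);
Tensor fallbackRunSubsequentOps(Tensor const& t, FusionParams const& fusion);

Tensor runAllReduceWithDeviceFusion(Tensor input, FusionParams const& fusion, std::size_t messageBytes)
{
    // Verify the input lives in a symmetric UB buffer; otherwise stage a copy.
    if (!isSymmetricUserBuffer(input))
    {
        input = copyIntoSymmetricBuffer(input);
    }
    // If the device-side fusion cannot handle this case, take the fallback path.
    if (!deviceFusionSupported(fusion, messageBytes))
    {
        return fallbackRunSubsequentOps(input, fusion);
    }
    // Otherwise launch the fused nccl_device kernel.
    return launchDeviceFusedAllReduce(input, fusion);
}
```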
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
@CodeRabbit review
I have a very strong suspicion: in general it is difficult to optimize a single kernel for small and large message sizes at the same time, which is why we want to use different strategies for different situations. ncclAllReduce does that automatically for us. This may be another reason to leave NCCL_SYMMETRIC and NCCL_DEVICE as separate strategies, for different message sizes. I am collecting data on this and will update the analysis with additional info.
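To make that concrete, here is a minimal sketch of message-size-based strategy selection. The cutoff value and the selection function are assumptions for illustration, not the heuristic used by TRT-LLM or by ncclAllReduce internally.

```cpp
// Illustrative only: the 1 MiB cutoff and this selection function are assumptions,
// not the actual strategy heuristic.
#include <cstddef>

enum class AllReduceStrategy
{
    NCCL,           // plain ncclAllReduce; picks its own algorithm per message size
    NCCL_SYMMETRIC, // symmetric-memory path, aimed at large messages
    NCCL_DEVICE     // fused device-API kernel, aimed at small/medium messages
};

AllReduceStrategy pickStrategy(std::size_t messageBytes, bool wantFusion)
{
    constexpr std::size_t kSmallMessageLimit = 1 << 20; // assumed 1 MiB cutoff
    if (wantFusion)
    {
        return messageBytes <= kSmallMessageLimit ? AllReduceStrategy::NCCL_DEVICE
                                                  : AllReduceStrategy::NCCL_SYMMETRIC;
    }
    return AllReduceStrategy::NCCL;
}
```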
return false;
}

// 6. Query actual kernel resource usage from kernel pointer for the specific unroll factor
Are there 4 and 5 steps too?
Yes, the numbering was out of order.
I removed the numbering, since it isn't necessary and is brittle to maintain.
The numbering is still incorrect.
target_link_libraries(tensorrt_llm_nccl_device tensorrt_llm_common)

# Install target
install(
Is this required? Can we only link statically to the TRT-LLM library?
I don't think this is necessary. I removed this section.
@@ -0,0 +1,516 @@
/*************************************************************************
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Please make the license header consistent with other files.
{
    int local_sms = 1;
    int dev = -1;
    cudaError_t cudaStatus = cudaGetDevice(&dev);
Can we use TLLM_CUDA_CHECK instead here and everywhere?
Suggested change:
- cudaError_t cudaStatus = cudaGetDevice(&dev);
+ TLLM_CUDA_CHECK(cudaGetDevice(&dev));
{
    // Get CUDA device properties
    int dev = -1;
    cudaError_t cudaStatus = cudaGetDevice(&dev);
same as above
return false;
}

// 6. Query actual kernel resource usage from kernel pointer for the specific unroll factor
The numbering is still incorrect.
return false;
}

// 8. Check occupancy
Fix numbering
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
Update license please.
I updated the license of the 2 files and double-checked that the other files are OK too.
    goto default_case;
}
@nv-lschneider is it possible to address this comment?
Commits (all signed-off by Ludwig Schneider <[email protected]> unless noted):
- … and NEW NCCL device API to use NCCL to fuse RMS Norm with AllReduce.
- …) (NVIDIA#7900) (signed-off by Yan Chunwei <[email protected]>)
- pre-commit changes
- clang formatting
- safe guarding NCCL 2.27 build
- fixing precommit formatting
- most of code rabbit comments
- adding missing semi-colon
- removing unused comment lines
- Clarifying the test on how to compare residual chunked and unchunked.
- fixing pre-commit
- fixing pre-commit
- fixing missing variable, rebase complete and tested
- using a grid stride loop with less blocks launched for large message sizes
- using functioning grid stride loop for NCCL_DEVICE. It helps with better performance at larger message sizes
- initial oneshot implementation
- minor tweaks to include one shot fixes
- enabling grid stride loop, but no perf benefit.
- addressing review feedback
- fix formatting
- better UB init handling
- accept multiple strategies
- test to debug mnnvl
- rebasing and addressing comments
- remove unneeded type decl
nv-lschneider left a comment
I rebased the code and addressed your comments. Thanks for your patience; rebasing takes a while.
    goto default_case;
}
Yes, using intentional fallthrough instead of goto.
(If we support more cases, we will have to refactor the default case out.)
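For reference, a minimal sketch of that pattern is shown below; the enum values and helper functions are placeholders, not the PR's actual names. The unsupported configuration falls through to the default handler with an explicit [[fallthrough]] annotation instead of goto default_case.

```cpp
// Placeholder sketch of intentional fallthrough replacing "goto default_case;".
// Enum values and helpers are illustrative, not the actual TRT-LLM names.
enum class FusionOp { ResidualRmsNorm, Unsupported };

bool deviceKernelAvailable();
void launchResidualRmsNorm();
void runFallbackPath();

void dispatchFusion(FusionOp op)
{
    switch (op)
    {
    case FusionOp::ResidualRmsNorm:
        if (deviceKernelAvailable())
        {
            launchResidualRmsNorm();
            break;
        }
        [[fallthrough]]; // previously expressed as: goto default_case;
    default:
        runFallbackPath();
        break;
    }
}
```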
k_chunk_size = a.size(1) // tensor_parallel_size
b.size(0) // tensor_parallel_size
Removing the unused line.
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
I updated the license of the 2 files and double-checked that the other files are OK too.
Tabrizian left a comment
LGTM, merging is blocked until the PyTorch container upgrades to NCCL 2.28.
Description
The MR introduces a new kernel launch mechanism to support kernels with the NCCL device API.
It implements 1 kernel to start with: RESIDUAL_RMS_NORM for fp16 types.
This new kernel is meant to replace/enhance the performance of AllReduce using the stable NCCL API.
This is the first kernel of potentially more variations for best performance. The default AR selection strategy is not impacted yet.
It is designed to be low latency for small to medium message sizes.
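For readers unfamiliar with the fusion, a scalar reference of what the AR+RMS norm epilogue computes per token is sketched below. This is only the mathematical contract under standard residual+RMSNorm conventions (eps placement, the in-place update, and the single-buffer output are assumptions); the real kernel runs distributed across ranks, vectorized, and in fp16. The real fusion typically also exposes the post-residual sum as a separate output; the sketch folds it into one buffer for brevity.

```cpp
// Scalar reference of the AllReduce + residual-add + RMSNorm contract, per token.
// Assumes "hidden" already holds the allreduced (summed-across-ranks) activations.
#include <cmath>
#include <cstddef>
#include <vector>

void residualRmsNormReference(std::vector<float>& hidden, std::vector<float> const& residual,
    std::vector<float> const& gamma, float eps)
{
    std::size_t const n = hidden.size();
    // Residual add on the allreduced activations.
    for (std::size_t i = 0; i < n; ++i)
    {
        hidden[i] += residual[i];
    }
    // RMS normalization with a learned per-channel scale.
    float sumSq = 0.f;
    for (std::size_t i = 0; i < n; ++i)
    {
        sumSq += hidden[i] * hidden[i];
    }
    float const invRms = 1.f / std::sqrt(sumSq / static_cast<float>(n) + eps);
    for (std::size_t i = 0; i < n; ++i)
    {
        hidden[i] = hidden[i] * invRms * gamma[i];
    }
}
```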
The MR uses the existing NCCLUBAllocator and extends it to hold necessary persistent resources like NCCL registered memory windows and device communicators.
The AllReduce operation is implemented as a new AllReduceStrategy and launched from the AllReduce plugin, cpp/tensorrt_llm/thop/allreduceOP.cpp. It launches its own new kernels at cpp/tensorrt_llm/kernels/nccl_device.
The kernel itself is highly templated to be flexible for future demands without impeding runtime performance.
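As a rough illustration of what "highly templated" can look like, a hypothetical kernel declaration is sketched below. The template parameters, the argument list, and the DevCommHandle placeholder are all assumptions and do not reproduce the actual declarations in cpp/tensorrt_llm/kernels/nccl_device.

```cpp
// Hypothetical shape of a templated fused-allreduce kernel declaration.
// DevCommHandle stands in for the NCCL device communicator handle that the
// real code obtains via getNCCLDevComm(); parameter names are illustrative.
struct DevCommHandle;

template <typename T,          // element type, e.g. half precision
    int kUnrollFactor,         // elements handled per thread per iteration
    bool kOneShot,             // one-shot vs. two-shot allreduce variant
    bool kFuseResidualRmsNorm> // whether the RMSNorm epilogue is fused in
__global__ void fusedAllReduceKernel(T* output, T const* input, T const* residual,
    T const* rmsNormWeight, float eps, int hiddenSize, DevCommHandle const* devComm);
```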
This MR implements the new kernel as a two-shot / fp16 implementation first.
It is already competitive in this form; however, after adoption of this kernel, further modifications and additions can be included to extend support in the future.
Test Coverage
tests/unittest/_torch/multi_gpu/test_nccl_device.py
tests/microbenchmark/all_reduce.py
The microbenchmark has been updated slightly. It now includes the new strategy and optionally also UB for comparison.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
Important Caveat
This change requires NCCL 2.28 to run successfully.
Since the current dev container of TRT-LLM does not use 2.28 yet, I would like to gather some feedback before 2.28 becomes available.
A real test will only be possible when version 2.28 is included in the dev container.
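For context on how such a requirement is commonly guarded at build time, a sketch is shown below. NCCL_VERSION_CODE and the NCCL_VERSION(X,Y,Z) macro come from nccl.h; the TLLM_HAS_NCCL_DEVICE_API macro and the helper function are hypothetical names used only for this illustration and are not necessarily what the PR does.

```cpp
// Sketch: compile the NCCL device-API path only when building against NCCL >= 2.28.
// TLLM_HAS_NCCL_DEVICE_API and ncclDeviceFusionAvailable() are illustrative names.
#include <nccl.h>

#if defined(NCCL_VERSION_CODE) && NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0)
#define TLLM_HAS_NCCL_DEVICE_API 1
#else
#define TLLM_HAS_NCCL_DEVICE_API 0
#endif

bool ncclDeviceFusionAvailable()
{
#if TLLM_HAS_NCCL_DEVICE_API
    return true;  // device communicators and fused kernels can be used
#else
    return false; // built against NCCL < 2.28: always take a fallback AllReduce path
#endif
}
```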