
Commit 56bf9d0

Introducing a new AR strategy that makes use of NCCL symmetric memory and the new NCCL device API to fuse RMS Norm with AllReduce.
Signed-off-by: Ludwig Schneider <[email protected]>
1 parent b181568 commit 56bf9d0
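
For readers skimming the diff: the idea behind the new strategy is to do the all-reduce and the RMSNorm epilogue in a single kernel instead of two passes over the hidden dimension. The sketch below is purely illustrative and is not the kernel added in config.cu (that file's diff is not rendered here); it uses plain float peer pointers, whereas the real implementation goes through NCCL windows (ncclWindow_t) and the NCCL device communicator. All names in it are hypothetical.

// Illustrative only: a naive one-shot AllReduce fused with an RMSNorm epilogue.
// One block per token; peerIn[r] is assumed to point at rank r's partial
// activations in symmetric memory.
#include <cuda_runtime.h>
#include <math.h>

__global__ void fusedAllReduceRmsNorm(float* const* peerIn, float* out, float const* weight,
    int nRanks, int hiddenDim, float eps)
{
    int const token = blockIdx.x;
    extern __shared__ float red[]; // blockDim.x floats, passed as dynamic shared memory

    float localSq = 0.f;
    for (int i = threadIdx.x; i < hiddenDim; i += blockDim.x)
    {
        // AllReduce step: sum the same element across every rank's buffer.
        float acc = 0.f;
        for (int r = 0; r < nRanks; ++r)
            acc += peerIn[r][token * hiddenDim + i];
        out[token * hiddenDim + i] = acc; // keep the reduced value for the epilogue
        localSq += acc * acc;
    }

    // Block-wide sum of squares (assumes blockDim.x is a power of two).
    red[threadIdx.x] = localSq;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride /= 2)
    {
        if (threadIdx.x < stride)
            red[threadIdx.x] += red[threadIdx.x + stride];
        __syncthreads();
    }
    float const rms = rsqrtf(red[0] / hiddenDim + eps);

    // RMSNorm epilogue fused into the same kernel: scale and apply gamma.
    for (int i = threadIdx.x; i < hiddenDim; i += blockDim.x)
        out[token * hiddenDim + i] *= rms * weight[i];
}

The committed kernel additionally handles residual and bias inputs, vectorized loads, and a sharded or unsharded residual output, as the LaunchConfig interface further down suggests.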

File tree

20 files changed: +1803 −28 lines


cpp/tensorrt_llm/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -86,3 +86,4 @@ add_subdirectory(groupRmsNormKernels)
 add_subdirectory(llama4MinLatencyKernels)
 add_subdirectory(dsv3MinLatencyKernels)
 add_subdirectory(causalConv1d)
+add_subdirectory(nccl_device)

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ enum class AllReduceStrategyType : int8_t
     LOWPRECISION = 6,
     MNNVL = 7,
     NCCL_SYMMETRIC = 8,
+    NCCL_DEVICE = 9,
 };
 
 enum class AllReduceStrategyConfig : int8_t
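
This hunk only adds the enumerator; the dispatch side is elsewhere in the tree. A hedged sketch of how a caller would opt into the new path (the dispatcher and its namespace are assumptions, only the enum values come from this diff):

// Hypothetical caller-side dispatch on the strategy enum; handleAllReduce()
// is a stand-in, and the tensorrt_llm::kernels namespace is assumed.
#include "tensorrt_llm/kernels/customAllReduceKernels.h"

void handleAllReduce(tensorrt_llm::kernels::AllReduceStrategyType strategy)
{
    using tensorrt_llm::kernels::AllReduceStrategyType;
    switch (strategy)
    {
    case AllReduceStrategyType::NCCL_SYMMETRIC:
        // existing path: NCCL collectives over window-registered (symmetric) buffers
        break;
    case AllReduceStrategyType::NCCL_DEVICE: // new in this commit
        // new path: NCCL device API kernel that fuses RMSNorm with the AllReduce
        break;
    default:
        break;
    }
}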

cpp/tensorrt_llm/kernels/nccl_device/CMakeLists.txt

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# CMakeLists.txt for nccl_device
+# This directory contains CUDA kernels and host launcher code
+
+# Enable CUDA
+enable_language(CUDA)
+
+# Create CUDA library
+add_library(tensorrt_llm_nccl_device
+    config.cu
+)
+
+# Set properties for the CUDA library
+set_target_properties(tensorrt_llm_nccl_device PROPERTIES
+    CUDA_STANDARD 17
+    CUDA_SEPARABLE_COMPILATION ON
+    POSITION_INDEPENDENT_CODE ON
+)
+
+# Include directories
+target_include_directories(tensorrt_llm_nccl_device PUBLIC
+    ${CMAKE_CURRENT_SOURCE_DIR}
+    ${CMAKE_CURRENT_SOURCE_DIR}/../..
+)
+
+# Link libraries
+target_link_libraries(tensorrt_llm_nccl_device
+    tensorrt_llm_common
+)
+
+# Install target
+install(TARGETS tensorrt_llm_nccl_device
+    LIBRARY DESTINATION lib
+    ARCHIVE DESTINATION lib
+)

cpp/tensorrt_llm/kernels/nccl_device/config.cu

Lines changed: 389 additions & 0 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/nccl_device/config.h

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+/*************************************************************************
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef TRTLLM_NCCL_DEVICE_CONFIG_H
+#define TRTLLM_NCCL_DEVICE_CONFIG_H
+
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <typeinfo>
+#include <vector>
+#include <cassert>
+#include <cuda_runtime.h>
+#include "nccl.h"
+#include "nccl_device.h"
+#include "vector_types.h"
+#include "constants.h"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/dataType.h"
+#include "tensorrt_llm/runtime/iBuffer.h"
+
+namespace tensorrt_llm::kernels::nccl_device {
+
+// Kernel launch information helper class
+class LaunchConfig {
+public:
+    const int hidden_dim;
+    const int num_tokens;
+    const int nRanks;
+    const int rank;
+    const bool useResidual;
+    const bool useBias;
+    const bool unshardResidualOut;
+protected:
+    int token_per_rank;
+    int start_token;
+    bool valid;
+    int threadsPerBlock;
+    int unrollFactor;
+
+    std::pair<int, int> pickLaunchCombo(const std::vector<std::pair<int, int>>& options);
+
+public:
+    // Constructor with dynamic block size calculation
+    LaunchConfig(const int hidden_dim, const int num_tokens, const int rank, const int nRanks, bool useResidual, bool useBias, bool unshardResidualOut);
+
+    inline int getThreadsPerBlock() const { return this->threadsPerBlock; }
+    int getUnrollFactor() const { return this->unrollFactor; }
+    virtual bool getValid() const = 0;
+    int getBlocksPerRank() const { return this->token_per_rank; }
+    int getStartToken() const { return this->start_token; }
+    virtual int getElementsPerVector() const = 0;
+    virtual nvinfer1::DataType getDataType() const = 0;
+    virtual void* getKernelPtr() const = 0;
+    virtual bool isValidConfig(int threadsPerBlock, int unrollFactor, int blocksPerRank) const = 0;
+
+    // Launcher functions as member functions
+    void launchRMSNorm(ncclWindow_t inWindow, ncclWindow_t outWindow,
+        const void* const residual, ncclWindow_t residualOutWindow,
+        const void* const weight, const void* const bias,
+        ncclDevComm devComm, const float eps, cudaStream_t stream) const;
+
+    bool supportsMultimem() const;
+
+protected:
+    // Pure virtual launch function that must be implemented by derived classes
+    virtual void launchKernel(ncclWindow_t inWindow, ncclWindow_t outWindow,
+        const void* const residual, ncclWindow_t residualOutWindow,
+        const void* const weight, const void* const bias,
+        ncclDevComm devComm, const float eps, cudaStream_t stream) const = 0;
+
+    // Logging output
+    std::string getLoggingString() const;
+};
+
+
+// Kernel launch information helper class
+template<typename T>
+class TypedLaunchConfig : public LaunchConfig {
+private:
+    nvinfer1::DataType mType;
+
+    // Private templated helper function to get kernel pointer for specific unroll factor
+    template<int Nunroll>
+    void* getKernelPtrForUnroll() const;
+
+    // Private helper function to get kernel pointer for any unroll factor
+    void* getKernelPtrForUnrollFactor(int unrollFactor) const;
+
+    // Private helper function to launch kernel for any unroll factor
+    void launchKernelForUnrollFactor(ncclWindow_t inWindow, ncclWindow_t outWindow,
+        const void* const residual, ncclWindow_t residualOutWindow,
+        const void* const weight, const void* const bias,
+        ncclDevComm devComm, const float eps, cudaStream_t stream,
+        const dim3& gridDim, const dim3& blockDim, const size_t sharedMemSize) const;
+
+    // Private templated helper function to launch kernel for specific unroll factor
+    template<int Nunroll>
+    void launchKernelForUnrollImpl(ncclWindow_t inWindow, ncclWindow_t outWindow,
+        const void* const residual, ncclWindow_t residualOutWindow,
+        const void* const weight, const void* const bias,
+        ncclDevComm devComm, const float eps, cudaStream_t stream,
+        const dim3& gridDim, const dim3& blockDim, const size_t sharedMemSize,
+        bool useResidual, bool useBias, bool unshardResidualOut,
+        int startToken, int hiddenDim, int numTokens) const;
+
+public:
+    using TN = typename VectorType<T>::type;
+    constexpr static int elementsPerVector = sizeof(TN) / sizeof(T);
+public:
+
+    virtual int getElementsPerVector() const { return this->elementsPerVector; }
+    virtual void* getKernelPtr() const override { return getKernelPtrForUnrollFactor(this->unrollFactor); }
+    virtual bool isValidConfig(int threadsPerBlock, int unrollFactor, int blocksPerRank) const override;
+
+    // Launch function that handles all the type-specific logic internally
+    virtual void launchKernel(ncclWindow_t inWindow, ncclWindow_t outWindow,
+        const void* const residual, ncclWindow_t residualOutWindow,
+        const void* const weight, const void* const bias,
+        ncclDevComm devComm, const float eps, cudaStream_t stream) const override;
+
+    // Constructor with dynamic block size calculation
+    TypedLaunchConfig(const int hidden_dim, const int num_tokens, const int rank, const int nRanks, bool useResidual, bool useBias, bool unshardResidualOut);
+    nvinfer1::DataType getDataType() const { return tensorrt_llm::runtime::TRTDataType<T>::value; }
+    virtual bool getValid() const { return this->valid; }
+
+};
+
+std::shared_ptr<LaunchConfig> makeLaunchConfig(nvinfer1::DataType dataType, const int hidden_dim, const int num_tokens, const int rank, const int nRanks, bool useResidual, bool useBias, bool unshardResidualOut);
+
+} // namespace tensorrt_llm::kernels::nccl_device
+
+#endif // TRTLLM_NCCL_DEVICE_CONFIG_H
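
As a rough orientation, host code drives this interface along the following lines. A minimal sketch, assuming the NCCL windows, device communicator, and weight/residual/bias buffers are created elsewhere; only the declarations above are used, the wrapper function itself and the include path are hypothetical:

// Hypothetical wrapper: pick a launch configuration for this shape/dtype and,
// if one exists, launch the fused AllReduce + RMSNorm kernel.
#include <memory>
#include "tensorrt_llm/kernels/nccl_device/config.h" // include path assumed

using namespace tensorrt_llm::kernels::nccl_device;

bool tryFusedAllReduceRmsNorm(nvinfer1::DataType dtype, int hiddenDim, int numTokens, int rank, int nRanks,
    ncclWindow_t inWindow, ncclWindow_t outWindow, ncclWindow_t residualOutWindow, void const* residual,
    void const* weight, void const* bias, ncclDevComm devComm, float eps, cudaStream_t stream)
{
    // Factory returns the TypedLaunchConfig<T> matching dtype; it also derives
    // threadsPerBlock, the unroll factor, and the per-rank token slice.
    std::shared_ptr<LaunchConfig> cfg = makeLaunchConfig(dtype, hiddenDim, numTokens, rank, nRanks,
        /*useResidual=*/residual != nullptr, /*useBias=*/bias != nullptr,
        /*unshardResidualOut=*/kUnshardCompletely);

    if (!cfg || !cfg->getValid())
    {
        return false; // no valid configuration: fall back to another AllReduce strategy
    }

    cfg->launchRMSNorm(inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream);
    return true;
}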

cpp/tensorrt_llm/kernels/nccl_device/constants.h

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRTLLM_NCCL_DEVICE_CONSTANTS_H
+#define TRTLLM_NCCL_DEVICE_CONSTANTS_H
+
+#include <cstdint>
+
+namespace tensorrt_llm::kernels::nccl_device {
+
+// CUDA and kernel constants
+constexpr int kWarpSize = 32;
+constexpr int kMaxThreadsPerBlock = 256; // Maximum block size configurable for performance. Corresponds to the shared memory requirement for cub::BlockReduce
+constexpr int kMinThreadsPerBlock = kWarpSize; // Minimum block size is a warp.
+constexpr int kMaxUnrollFactor = 8; // We require manual instantiation and switches. Changing the number alone is not enough; see the launcher function for details
+constexpr bool kUnshardCompletely = true;
+} // namespace tensorrt_llm::kernels::nccl_device
+
+#endif // TRTLLM_NCCL_DEVICE_CONSTANTS_H
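
The kMaxUnrollFactor comment refers to the usual pattern of bridging a runtime unroll factor to a compile-time template parameter; the generic sketch below shows why simply raising the constant is not enough. The names, the launcher, and the exact set of supported factors are hypothetical; the real dispatch lives in config.cu.

#include <cstdio>
#include <stdexcept>

// Stand-in for a kernel launch that needs the unroll factor as a compile-time
// constant (e.g. to drive #pragma unroll and register allocation).
template <int Nunroll>
void launchForUnroll()
{
    std::printf("launching kernel instantiated with Nunroll=%d\n", Nunroll);
}

// Runtime-to-compile-time bridge: every supported unroll factor needs its own
// case and its own explicit instantiation, so extending kMaxUnrollFactor also
// means extending this switch (and the instantiations) by hand.
inline void launchForUnrollFactor(int unrollFactor)
{
    switch (unrollFactor)
    {
    case 1: launchForUnroll<1>(); break;
    case 2: launchForUnroll<2>(); break;
    case 4: launchForUnroll<4>(); break;
    case 8: launchForUnroll<8>(); break; // kMaxUnrollFactor
    default: throw std::invalid_argument("unsupported unroll factor");
    }
}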
