Skip to content

Commit 08bf35d

Browse files
committed
first step new allreduce
Signed-off-by: Ludwig Schneider <[email protected]>

Squashed commit messages:
- better UB init handling (Signed-off-by: Ludwig Schneider <[email protected]>)
- accept multiple strategies (Signed-off-by: Ludwig Schneider <[email protected]>)
- test to debug mnnvl (Signed-off-by: Ludwig Schneider <[email protected]>)
- rebasing and addressing comments (Signed-off-by: Ludwig Schneider <[email protected]>)
- remove unneeded type decl (Signed-off-by: Ludwig Schneider <[email protected]>)
1 parent 9367328 commit 08bf35d

File tree

11 files changed

+178
-80
lines changed

11 files changed

+178
-80
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text
1212
docs/source/blogs/media/tech_blog10_baseline_performance_detail.png filter=lfs diff=lfs merge=lfs -text
1313
docs/source/blogs/media/tech_blog10_full_strategy_performance.png filter=lfs diff=lfs merge=lfs -text
1414
docs/source/blogs/media/tech_blog10_context_wait_performance.png filter=lfs diff=lfs merge=lfs -text
15+
ATTRIBUTIONS-CPP.md filter=lfs diff=lfs merge=lfs -text

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,7 @@ repos:
13951395
exclude: |
13961396
(?x)^(.*cubin.cpp | .*cubin.h)$
13971397
- id: check-merge-conflict
1398+
exclude: ^ATTRIBUTIONS-CPP\.md$
13981399
- id: check-symlinks
13991400
- id: detect-private-key
14001401
- id: end-of-file-fixer

ATTRIBUTIONS-CPP.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:7c56252c4d635377c3202c34f8f049053ceb9567fc1f243f4fa5ba91d762176b
3+
size 818494

cpp/tensorrt_llm/kernels/nccl_device/CMakeLists.txt

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
3+
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6+
# use this file except in compliance with the License. You may obtain a copy of
7+
# the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
# License for the specific language governing permissions and limitations under
15+
# the License.
16+
#
17+
118
# CMakeLists.txt for nccl_device This directory contains CUDA kernels and host
219
# launcher code
320

@@ -20,9 +37,3 @@ target_include_directories(
2037

2138
# Link libraries
2239
target_link_libraries(tensorrt_llm_nccl_device tensorrt_llm_common)
23-
24-
# Install target
25-
install(
26-
TARGETS tensorrt_llm_nccl_device
27-
LIBRARY DESTINATION lib
28-
ARCHIVE DESTINATION lib)

cpp/tensorrt_llm/kernels/nccl_device/config.cu

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1-
/*************************************************************************
2-
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
33
*
4-
* See LICENSE.txt for license information
5-
************************************************************************/
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
616

717
#include "config.h"
818
#include "nccl.h"
919
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0)
10-
#include "kernels.h"
20+
#include "kernels.cuh"
1121
#endif
1222
#include "tensorrt_llm/common/cudaUtils.h"
1323
#include "tensorrt_llm/common/envUtils.h"
@@ -176,19 +186,9 @@ bool TypedLaunchConfig<T>::isValidConfig(int threadsPerBlock, int unrollFactor)
176186
{
177187
// Get CUDA device properties
178188
int dev = -1;
179-
cudaError_t cudaStatus = cudaGetDevice(&dev);
180-
if (cudaStatus != cudaSuccess)
181-
{
182-
TLLM_LOG_ERROR("Failed to get CUDA device: " + std::string(cudaGetErrorString(cudaStatus)));
183-
return false;
184-
}
189+
TLLM_CUDA_CHECK(cudaGetDevice(&dev));
185190
cudaDeviceProp deviceProp;
186-
cudaStatus = cudaGetDeviceProperties(&deviceProp, dev);
187-
if (cudaStatus != cudaSuccess)
188-
{
189-
TLLM_LOG_ERROR("Failed to get CUDA device properties: " + std::string(cudaGetErrorString(cudaStatus)));
190-
return false;
191-
}
191+
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, dev));
192192

193193
// Check threads per block limits
194194
if (threadsPerBlock <= 0 || threadsPerBlock > deviceProp.maxThreadsPerBlock)
@@ -217,13 +217,7 @@ bool TypedLaunchConfig<T>::isValidConfig(int threadsPerBlock, int unrollFactor)
217217

218218
// Get actual register and shared memory usage from the kernel
219219
cudaFuncAttributes funcAttrib;
220-
cudaError_t attrStatus = cudaFuncGetAttributes(&funcAttrib, reinterpret_cast<void const*>(kernelPtr));
221-
if (attrStatus != cudaSuccess)
222-
{
223-
TLLM_LOG_WARNING(
224-
"Failed to get kernel attributes for validation: " + std::string(cudaGetErrorString(attrStatus)));
225-
return false;
226-
}
220+
TLLM_CUDA_CHECK(cudaFuncGetAttributes(&funcAttrib, reinterpret_cast<void const*>(kernelPtr)));
227221

228222
// Check register usage
229223
int const totalRegistersPerBlock = funcAttrib.numRegs * threadsPerBlock;

cpp/tensorrt_llm/kernels/nccl_device/config.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1-
/*************************************************************************
2-
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
33
*
4-
* See LICENSE.txt for license information
5-
************************************************************************/
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
616

717
#ifndef TRTLLM_NCCL_DEVICE_CONFIG_H
818
#define TRTLLM_NCCL_DEVICE_CONFIG_H
@@ -115,8 +125,6 @@ template <typename T>
115125
class TypedLaunchConfig : public LaunchConfig
116126
{
117127
private:
118-
nvinfer1::DataType mType;
119-
120128
// Private templated helper function to get kernel pointer for specific unroll factor
121129
template <int Nunroll>
122130
void* getKernelPtrForUnroll() const;

cpp/tensorrt_llm/kernels/nccl_device/kernels.h renamed to cpp/tensorrt_llm/kernels/nccl_device/kernels.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
* limitations under the License.
1515
*/
1616

17-
#ifndef TRTLLM_NCCL_DEVICE_KERNELS_H
18-
#define TRTLLM_NCCL_DEVICE_KERNELS_H
17+
#ifndef TRTLLM_NCCL_DEVICE_KERNELS_CUH
18+
#define TRTLLM_NCCL_DEVICE_KERNELS_CUH
1919

2020
#include "constants.h"
2121
#include "multimem.h"
@@ -252,4 +252,4 @@ __global__ void fusedAllReduceRMSNormKernel(ncclWindow_t input_win, ncclWindow_t
252252

253253
} // namespace tensorrt_llm::kernels::nccl_device
254254

255-
#endif // TRTLLM_NCCL_DEVICE_KERNELS_H
255+
#endif // TRTLLM_NCCL_DEVICE_KERNELS_CUH

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -558,19 +558,14 @@ class AllreduceOp
558558
norm_weight.value().data_ptr(), nullptr, devComm, mEps, stream);
559559
return {norm_out, residual_out};
560560
}
561-
else
562-
{
563-
// Fall back to old strategy with warning
564-
TLLM_LOG_WARNING(
565-
"NCCL device Fused AR not supported for data type %d, hidden size %d & %d nRanks on current "
566-
"architecture. Falling back to standard allreduce + separate RMSNorm.",
567-
static_cast<int>(mType), hidden_size, nRanks);
568-
569-
goto default_case;
570-
}
561+
// Fall back to old strategy with warning
562+
TLLM_LOG_WARNING(
563+
"NCCL device Fused AR not supported for data type %d, hidden size %d & %d nRanks on current "
564+
"architecture. Falling back to standard allreduce + separate RMSNorm.",
565+
static_cast<int>(mType), hidden_size, nRanks);
571566
}
567+
// Intentional fallthrough to default
572568
default:
573-
default_case:
574569
NCCLCHECK(ncclAllReduce(
575570
ub_buffer0.addr, ub_buffer1.addr, size, (*getDtypeMap())[mType], ncclSum, *rawComm, stream));
576571
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out);

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
275275
package_data={
276276
'tensorrt_llm': package_data,
277277
},
278-
license_files=get_license(),
278+
license_files=["LICENSE", "ATTRIBUTIONS-CPP.md"],
279279
entry_points={
280280
'console_scripts': [
281281
'trtllm-build=tensorrt_llm.commands.build:main',

0 commit comments

Comments (0)