diff --git a/.codespellignore b/.codespellignore new file mode 100644 index 00000000000..f1a24b39f52 --- /dev/null +++ b/.codespellignore @@ -0,0 +1 @@ +commIter diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a75e55b79d3..ac9278cfe09 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1442,7 +1442,7 @@ repos: additional_dependencies: - tomli # add ignore words list - args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"] + args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*", "-I", ".codespellignore"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.4 hooks: diff --git a/cpp/tensorrt_llm/kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/CMakeLists.txt index 74680318170..03140c4e478 100644 --- a/cpp/tensorrt_llm/kernels/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/CMakeLists.txt @@ -86,3 +86,4 @@ add_subdirectory(groupRmsNormKernels) add_subdirectory(llama4MinLatencyKernels) add_subdirectory(dsv3MinLatencyKernels) add_subdirectory(causalConv1d) +add_subdirectory(nccl_device) diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h index c96a1b30649..173c2ffabeb 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h @@ -58,6 +58,7 @@ enum class AllReduceStrategyType : int8_t LOWPRECISION = 6, MNNVL = 7, NCCL_SYMMETRIC = 8, + NCCL_DEVICE = 9, }; enum class AllReduceStrategyConfig : int8_t @@ -119,6 +120,7 @@ inline std::ostream& operator<<(std::ostream& os, AllReduceStrategyType op) case AllReduceStrategyType::LOWPRECISION: os << "LOWPRECISION"; break; case AllReduceStrategyType::MNNVL: os << "MNNVL"; break; case AllReduceStrategyType::NCCL_SYMMETRIC: os << "NCCL_SYMMETRIC"; break; + case AllReduceStrategyType::NCCL_DEVICE: os << "NCCL_DEVICE"; break; } return os; } @@ -130,6 +132,15 @@ inline std::string toString(AllReduceStrategyType op) return oss.str(); } +// Helper function to determine if a strategy should skip topology detection +// These strategies manage connectivity internally +inline bool shouldSkipTopologyDetection(AllReduceStrategyType strategy) +{ + return (strategy == AllReduceStrategyType::NCCL || strategy == AllReduceStrategyType::NCCL_SYMMETRIC + || strategy == AllReduceStrategyType::NCCL_DEVICE || strategy == AllReduceStrategyType::UB + || strategy == AllReduceStrategyType::MNNVL); +} + struct AllReduceFusionParams { AllReduceFusionParams() diff --git a/cpp/tensorrt_llm/kernels/nccl_device/CMakeLists.txt b/cpp/tensorrt_llm/kernels/nccl_device/CMakeLists.txt new file mode 100644 index 00000000000..631303638f4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/CMakeLists.txt @@ -0,0 +1,39 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# + +# CMakeLists.txt for nccl_device This directory contains CUDA kernels and host +# launcher code + +# Enable CUDA +enable_language(CUDA) + +# Create CUDA library +add_library(tensorrt_llm_nccl_device config.cu) + +# Set properties for the CUDA library +set_target_properties( + tensorrt_llm_nccl_device + PROPERTIES CUDA_STANDARD 17 CUDA_SEPARABLE_COMPILATION ON + POSITION_INDEPENDENT_CODE ON) + +# Include directories +target_include_directories( + tensorrt_llm_nccl_device PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../..) + +# Link libraries +target_link_libraries(tensorrt_llm_nccl_device tensorrt_llm_common) diff --git a/cpp/tensorrt_llm/kernels/nccl_device/config.cu b/cpp/tensorrt_llm/kernels/nccl_device/config.cu new file mode 100644 index 00000000000..774108652c4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/config.cu @@ -0,0 +1,565 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" +#include "nccl.h" +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) +#include "kernels.cuh" +#endif +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include "vector_types.h" +#include +#include +#include + +namespace tensorrt_llm::kernels::nccl_device +{ + +std::pair LaunchConfig::pickLaunchCombo(std::vector> const& options) +{ + return options.at(0); // Experimenting found that using less unroll and more threads per block is better +} + +LaunchConfig::LaunchConfig(int const hidden_dim, int const num_tokens, int const rank, int const nRanks, + bool useResidual, bool useBias, int const num_sms) + : hidden_dim(hidden_dim) + , num_tokens(num_tokens) + , rank(rank) + , nRanks(nRanks) + , useResidual(useResidual) + , useBias(useBias) + , oneShot(false) + , token_per_rank(-1) + , start_token(-1) + , num_sms(num_sms) + , valid(false) + , threadsPerBlock(0) + , unrollFactor(0) +{ + // No grid-stride + int const base_tokens = num_tokens / nRanks; + this->num_sms = base_tokens; + int const remainder = num_tokens % nRanks; + if (remainder > 0) + this->num_sms += 1; + + // TODO hard coded value for now. Maybe some tuning possible + if (num_tokens <= 32) + this->oneShot = true; + + if (this->oneShot) + { + // In one shot mode, each rank processes all tokens + this->token_per_rank = num_tokens; + this->start_token = 0; + this->num_sms = num_tokens; + } + else + { + // Distribute tokens across ranks: first 'remainder' ranks get one extra token + this->token_per_rank = base_tokens + (rank < remainder ? 1 : 0); + this->start_token = rank * base_tokens + std::min(rank, remainder); + } + + auto maxCTAEnv = tensorrt_llm::common::getIntEnv("TLLM_NCCL_DEVICE_AR_RMS_MAX_CTA"); + if (maxCTAEnv.has_value()) + { + if (maxCTAEnv.value() > 0) + this->num_sms = maxCTAEnv.value(); + else + { + TLLM_LOG_WARNING("TLLM_NCCL_DEVICE_AR_RMS_MAX_CTA was detected as <= 0 and is ignored."); + } + } +} + +std::string LaunchConfig::getLoggingString() const +{ + std::ostringstream oss; + if (this->valid) + { + oss << "Launching Kernel: NCCL fused AR kernel!\n"; + } + else + { + oss << "Unable to Launch Kernel: NCCL fused AR kernel "; + } + oss << "\tConfiguration:\n"; + oss << "\t\t ThreadsPerBlock: " << this->getThreadsPerBlock() << "\n"; + oss << "\t\t UnrollFactor: " << this->getUnrollFactor() << "\n"; + oss << "\t\t TokensPerRank: " << this->token_per_rank << "\n"; + oss << "\t\t NumSMs: " << this->getNumSMs() << "\n"; + oss << "\t\t VectorInfo: " << this->getElementsPerVector() << "\n"; + oss << "\t\t HiddenDim: " << this->getElementsPerVector() * this->getUnrollFactor() * this->getThreadsPerBlock() + << " = " << this->hidden_dim << "\n"; + oss << "\t\t NumTokens: " << this->num_tokens << "\n"; + oss << "\t\t StartToken: " << this->getStartToken() << "\n"; + + return oss.str(); +} + +// Template class implementation +template +TypedLaunchConfig::TypedLaunchConfig(int const hidden_dim, int const num_tokens, int const rank, int const nRanks, + bool useResidual, bool useBias, int const num_sms) + : LaunchConfig(hidden_dim, num_tokens, rank, nRanks, useResidual, useBias, num_sms) +{ + + // Calculate optimal block size to achieve better coverage + int const maxThreadsPerBlock = kMaxThreadsPerBlock; // Maximum allowed block size + int const minThreadsPerBlock = kMinThreadsPerBlock; // Minimum block size (warp size) + + std::vector> valid_launch_combo; + + // Try to find a block size that gives optimal coverage + for (int testThreadsPerBlock = maxThreadsPerBlock; testThreadsPerBlock >= minThreadsPerBlock; + testThreadsPerBlock -= minThreadsPerBlock) + { + for (int testUnrollFactor = 1; testUnrollFactor <= kMaxUnrollFactor; testUnrollFactor += 1) + { + size_t const elementsProcessedPerBlock = elementsPerVector * testUnrollFactor * testThreadsPerBlock; + + if (elementsProcessedPerBlock == hidden_dim) + { + // Validate that this configuration can actually be launched + if (isValidConfig(testThreadsPerBlock, testUnrollFactor)) + { + valid_launch_combo.push_back(std::make_pair(testThreadsPerBlock, testUnrollFactor)); + } + } + } + } + + if (valid_launch_combo.size() > 0) + { + std::pair optimal_launch_combo = pickLaunchCombo(valid_launch_combo); + + // Set the calculated optimal values + this->threadsPerBlock = optimal_launch_combo.first; + this->unrollFactor = optimal_launch_combo.second; + this->valid = true; + } +} + +std::shared_ptr makeLaunchConfig(nvinfer1::DataType dataType, int const hidden_dim, int const num_tokens, + int const rank, int const nRanks, bool useResidual, bool useBias, int const num_sms) +{ + switch (dataType) + { + case nvinfer1::DataType::kHALF: + return std::make_shared>( + hidden_dim, num_tokens, rank, nRanks, useResidual, useBias, num_sms); + case nvinfer1::DataType::kBF16: + return std::make_shared>( + hidden_dim, num_tokens, rank, nRanks, useResidual, useBias, num_sms); + case nvinfer1::DataType::kFLOAT: + return std::make_shared>( + hidden_dim, num_tokens, rank, nRanks, useResidual, useBias, num_sms); + default: TLLM_THROW("Unimplemented data type for fused NCCL AllReduce launches."); + } + return nullptr; +} + +// Explicit template instantiations +template class TypedLaunchConfig; +template class TypedLaunchConfig<__nv_bfloat16>; +template class TypedLaunchConfig; + +// Implementation of launch configuration validation +template +bool TypedLaunchConfig::isValidConfig(int threadsPerBlock, int unrollFactor) const +{ + // Get CUDA device properties + int dev = -1; + TLLM_CUDA_CHECK(cudaGetDevice(&dev)); + cudaDeviceProp deviceProp; + TLLM_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, dev)); + + // Check threads per block limits + if (threadsPerBlock <= 0 || threadsPerBlock > deviceProp.maxThreadsPerBlock) + { + return false; + } + + // Check warp size alignment + if (threadsPerBlock % deviceProp.warpSize != 0) + { + return false; + } + + // Check unroll factor validity + if (unrollFactor <= 0 || unrollFactor > kMaxUnrollFactor) + { + return false; + } + + // Query actual kernel resource usage from kernel pointer for the specific unroll factor + void* kernelPtr = this->getKernelPtrForUnrollFactor(unrollFactor); + if (kernelPtr == nullptr) + { + return false; + } + + // Get actual register and shared memory usage from the kernel + cudaFuncAttributes funcAttrib; + TLLM_CUDA_CHECK(cudaFuncGetAttributes(&funcAttrib, reinterpret_cast(kernelPtr))); + + // Check register usage + int const totalRegistersPerBlock = funcAttrib.numRegs * threadsPerBlock; + if (totalRegistersPerBlock > deviceProp.regsPerBlock) + { + return false; + } + + // Check shared memory usage + if (funcAttrib.sharedSizeBytes > deviceProp.sharedMemPerBlock) + { + return false; + } + + // Check occupancy + int const warpsPerBlock = threadsPerBlock / deviceProp.warpSize; + int const maxWarpsPerSM = deviceProp.maxThreadsPerMultiProcessor / deviceProp.warpSize; + int const maxBlocksPerSM = deviceProp.maxThreadsPerMultiProcessor / threadsPerBlock; + + if (warpsPerBlock > maxWarpsPerSM) + { + return false; + } + + if (maxBlocksPerSM <= 0) + { + return false; + } + return true; +} + +// Template function implementations +template +template +void* TypedLaunchConfig::getKernelPtrForUnroll() const +{ + using TN = typename VectorType::type; + + void* result = nullptr; + if (oneShot) + { + if (useResidual && useBias) + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + else if (useResidual && !useBias) + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + else if (!useResidual && useBias) + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + else + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + } + else + { + if (useResidual && useBias) + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + else if (useResidual && !useBias) + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + else if (!useResidual && useBias) + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + else + { + result = reinterpret_cast(fusedAllReduceRMSNormKernel); + } + } + + return result; +} + +template +void* TypedLaunchConfig::getKernelPtrForUnrollFactor(int unrollFactor) const +{ +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) + void* result = nullptr; + switch (unrollFactor) + { + case 1: result = getKernelPtrForUnroll<1>(); break; + case 2: result = getKernelPtrForUnroll<2>(); break; + case 3: result = getKernelPtrForUnroll<3>(); break; + case 4: result = getKernelPtrForUnroll<4>(); break; + case 5: result = getKernelPtrForUnroll<5>(); break; + case 6: result = getKernelPtrForUnroll<6>(); break; + case 7: result = getKernelPtrForUnroll<7>(); break; + case 8: result = getKernelPtrForUnroll<8>(); break; + default: result = nullptr; break; + } + + return result; +#else + return nullptr; +#endif +} + +// Function to launch kernel for any unroll factor (shares logic with getKernelPtrForUnrollFactor) +template +void TypedLaunchConfig::launchKernelForUnrollFactor(ncclWindow_t inWindow, ncclWindow_t outWindow, + void const* const residual, ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, + ncclDevComm devComm, float const eps, cudaStream_t stream) const +{ +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) + // Use the same logic as getKernelPtrForUnrollFactor but launch the kernel directly + switch (this->unrollFactor) + { + case 1: + this->launchKernelForUnrollImpl<1>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 2: + this->launchKernelForUnrollImpl<2>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 3: + this->launchKernelForUnrollImpl<3>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 4: + this->launchKernelForUnrollImpl<4>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 5: + this->launchKernelForUnrollImpl<5>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 6: + this->launchKernelForUnrollImpl<6>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 7: + this->launchKernelForUnrollImpl<7>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + case 8: + this->launchKernelForUnrollImpl<8>( + inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); + break; + default: + TLLM_CHECK_WITH_INFO(false, "Invalid unroll factor %d for %s precision. Supported values: 1-8", + this->unrollFactor, typeid(T).name()); + } +#else + TLLM_THROW("NCCL device kernels not available (NCCL version < 2.28). Cannot launch kernel."); +#endif +} + +// Template implementation that shares the exact same logic as getKernelPtrForUnroll +template +template +void TypedLaunchConfig::launchKernelForUnrollImpl(ncclWindow_t inWindow, ncclWindow_t outWindow, + void const* const residual, ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, + ncclDevComm devComm, float const eps, cudaStream_t stream) const +{ +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) + using TN = typename VectorType::type; + + // Calculate grid and block dimensions from config members + // Use num_sms for grid dimension to match available hardware parallelism + dim3 const gridDim(this->num_sms, 1, 1); + dim3 const blockDim(this->threadsPerBlock, 1, 1); + size_t const sharedMemSize = 0; + + // Use the exact same logic as getKernelPtrForUnroll but launch the kernel + if (this->oneShot) + { + if (this->useResidual && this->useBias) + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + else if (this->useResidual && !this->useBias) + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + else if (!this->useResidual && this->useBias) + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + else + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + } + else + { + if (this->useResidual && this->useBias) + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + else if (this->useResidual && !this->useBias) + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + else if (!this->useResidual && this->useBias) + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + else + { + fusedAllReduceRMSNormKernel + <<>>(inWindow, outWindow, static_cast(residual), + residualOutWindow, static_cast(weight), static_cast(bias), this->start_token, + this->hidden_dim, this->token_per_rank, devComm, eps); + } + } +#else + TLLM_THROW("NCCL device kernels not available (NCCL version < 2.28). Cannot launch kernel."); +#endif +} + +// Implementation of launch function that handles all type-specific logic +template +void TypedLaunchConfig::launchKernel(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const +{ +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) + using TN = typename VectorType::type; + + // Launch kernel using runtime template parameter selection + launchKernelForUnrollFactor(inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); +#else + TLLM_THROW("NCCL device kernels not available (NCCL version < 2.28). Cannot launch kernel."); +#endif +} + +// Member function implementations +void LaunchConfig::launchRMSNorm(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const +{ + // Input validation + TLLM_CHECK_WITH_INFO(inWindow != nullptr, "NCCL inWindow needs to be initialized."); + TLLM_CHECK_WITH_INFO(outWindow != nullptr, "NCCL outWindow needs to be initialized."); + + TLLM_CHECK_WITH_INFO(eps >= 0.0f, "Epsilon must be non-negative, got %f", eps); + TLLM_CHECK_WITH_INFO(weight != nullptr, "Weight pointer cannot be null"); + TLLM_CHECK_WITH_INFO(residualOutWindow != nullptr, "Residual output pointer cannot be null"); + TLLM_CHECK_WITH_INFO(residual != nullptr, "Residual needs to be a valid pointer"); + TLLM_CHECK_WITH_INFO(this->getValid(), "LaunchConfig invalid"); + TLLM_CHECK_WITH_INFO(bias == nullptr, "we are not supporting a bias here."); + + // Delegate all launch logic to the config class + this->launchKernel(inWindow, outWindow, residual, residualOutWindow, weight, bias, devComm, eps, stream); +} + +// Runtime function to check if multimem is supported for a given data type +bool LaunchConfig::supportsMultimem() const +{ + nvinfer1::DataType dataType = this->getDataType(); + bool isValid = this->getValid(); + + TLLM_LOG_DEBUG("supportsMultimem() called: dataType=%d, nRanks=%d, valid=%d", static_cast(dataType), + this->nRanks, isValid); + +#ifdef ARCH_HAS_MULTIMEM + TLLM_LOG_DEBUG(" ARCH_HAS_MULTIMEM is defined"); + + // Note: 2 ranks are now supported for multimem + TLLM_LOG_DEBUG(" nRanks=%d (multimem supports 2+ ranks)", this->nRanks); + + // Basic types are always supported on SM90+ + switch (dataType) + { + case nvinfer1::DataType::kFLOAT: // float + { + TLLM_LOG_DEBUG(" DataType is FLOAT, checking getValid()=%d", isValid); + return this->getValid(); + } + // Half and BFloat16 with .acc::f32 qualifier (SM90+) +#ifdef ARCH_HAS_MULTIMEM_ACC_F32 + case nvinfer1::DataType::kHALF: // half + { + TLLM_LOG_DEBUG(" DataType is HALF, ARCH_HAS_MULTIMEM_ACC_F32 defined, checking getValid()=%d", isValid); + return this->getValid(); + } + case nvinfer1::DataType::kBF16: // __nv_bfloat16 + { + TLLM_LOG_DEBUG(" DataType is BF16, ARCH_HAS_MULTIMEM_ACC_F32 defined, checking getValid()=%d", isValid); + return this->getValid(); + } +#else + case nvinfer1::DataType::kHALF: // half + case nvinfer1::DataType::kBF16: // __nv_bfloat16 + { + TLLM_LOG_DEBUG(" DataType is HALF/BF16 but ARCH_HAS_MULTIMEM_ACC_F32 NOT defined, returning FALSE"); + return false; + } +#endif // ARCH_HAS_MULTIMEM_ACC_F32 + + // FP8 types with .acc::f16 qualifier (SM100+) +#ifdef ARCH_HAS_MULTIMEM_FP8 + case nvinfer1::DataType::kFP8: // FP8 (either E5M2 or E4M3) + { + TLLM_LOG_DEBUG(" DataType is FP8, ARCH_HAS_MULTIMEM_FP8 defined, checking getValid()=%d", isValid); + return this->getValid(); + } +#else + case nvinfer1::DataType::kFP8: + { + TLLM_LOG_DEBUG(" DataType is FP8 but ARCH_HAS_MULTIMEM_FP8 NOT defined, returning FALSE"); + return false; + } +#endif // ARCH_HAS_MULTIMEM_FP8 + default: + TLLM_LOG_DEBUG(" DataType %d not supported for multimem, returning FALSE", static_cast(dataType)); + return false; + } +#else + TLLM_LOG_DEBUG(" ARCH_HAS_MULTIMEM NOT defined, returning FALSE"); + return false; +#endif // ARCH_HAS_MULTIMEM +} + +} // namespace tensorrt_llm::kernels::nccl_device diff --git a/cpp/tensorrt_llm/kernels/nccl_device/config.h b/cpp/tensorrt_llm/kernels/nccl_device/config.h new file mode 100644 index 00000000000..a35e9c05a86 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/config.h @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRTLLM_NCCL_DEVICE_CONFIG_H +#define TRTLLM_NCCL_DEVICE_CONFIG_H + +#include "constants.h" +#include "nccl.h" +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) +#include "nccl_device.h" +#endif +#if NCCL_VERSION_CODE <= NCCL_VERSION(2, 28, 0) +using ncclDevComm = void*; +#endif +#if NCCL_VERSION_CODE <= NCCL_VERSION(2, 27, 0) +using ncclWindow_t = void*; +#endif + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/dataType.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "vector_types.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::kernels::nccl_device +{ + +// Kernel launch information helper class +class LaunchConfig +{ +public: + int const hidden_dim; + int const num_tokens; + int const nRanks; + int const rank; + bool const useResidual; + bool const useBias; + +protected: + bool oneShot; + int token_per_rank; + int start_token; + int num_sms; + bool valid; + int threadsPerBlock; + int unrollFactor; + + std::pair pickLaunchCombo(std::vector> const& options); + +public: + // Constructor with dynamic block size calculation + LaunchConfig(int const hidden_dim, int const num_tokens, int const rank, int const nRanks, bool useResidual, + bool useBias, int const num_sms = -1); + + inline int getThreadsPerBlock() const + { + return this->threadsPerBlock; + } + + int getUnrollFactor() const + { + return this->unrollFactor; + } + + int getNumSMs() const + { + return this->num_sms; + } + + virtual bool getValid() const = 0; + + int getBlocksPerRank() const + { + return this->token_per_rank; + } + + int getStartToken() const + { + return this->start_token; + } + + virtual int getElementsPerVector() const = 0; + virtual nvinfer1::DataType getDataType() const = 0; + virtual bool isValidConfig(int threadsPerBlock, int unrollFactor) const = 0; + + // Launcher functions as member functions + void launchRMSNorm(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const; + + bool supportsMultimem() const; + + // Logging output + std::string getLoggingString() const; + +protected: + // Pure virtual launch function that must be implemented by derived classes + virtual void launchKernel(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const + = 0; +}; + +// Kernel launch information helper class +template +class TypedLaunchConfig : public LaunchConfig +{ +private: + // Private templated helper function to get kernel pointer for specific unroll factor + template + void* getKernelPtrForUnroll() const; + + // Private helper function to get kernel pointer for any unroll factor + void* getKernelPtrForUnrollFactor(int unrollFactor) const; + + // Private helper function to launch kernel for any unroll factor + void launchKernelForUnrollFactor(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const; + + // Private templated helper function to launch kernel for specific unroll factor + template + void launchKernelForUnrollImpl(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const; + +public: + using TN = typename VectorType::type; + constexpr static int elementsPerVector = sizeof(TN) / sizeof(T); + +public: + virtual int getElementsPerVector() const + { + return this->elementsPerVector; + } + + virtual bool isValidConfig(int threadsPerBlock, int unrollFactor) const override; + + // Launch function that handles all the type-specific logic internally + virtual void launchKernel(ncclWindow_t inWindow, ncclWindow_t outWindow, void const* const residual, + ncclWindow_t residualOutWindow, void const* const weight, void const* const bias, ncclDevComm devComm, + float const eps, cudaStream_t stream) const override; + + // Constructor with dynamic block size calculation + TypedLaunchConfig(int const hidden_dim, int const num_tokens, int const rank, int const nRanks, bool useResidual, + bool useBias, int const num_sms = -1); + + nvinfer1::DataType getDataType() const + { + return tensorrt_llm::runtime::TRTDataType::value; + } + + virtual bool getValid() const + { +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) + return this->valid; +#else + TLLM_LOG_WARNING("NCCL device kernels not available (NCCL version < 2.28). LaunchConfig will be invalid."); + return false; +#endif + } +}; + +std::shared_ptr makeLaunchConfig(nvinfer1::DataType dataType, int const hidden_dim, int const num_tokens, + int const rank, int const nRanks, bool useResidual, bool useBias, int const num_sms = -1); + +} // namespace tensorrt_llm::kernels::nccl_device + +#endif // TRTLLM_NCCL_DEVICE_CONFIG_H diff --git a/cpp/tensorrt_llm/kernels/nccl_device/constants.h b/cpp/tensorrt_llm/kernels/nccl_device/constants.h new file mode 100644 index 00000000000..40fc8f88002 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/constants.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRTLLM_NCCL_DEVICE_CONSTANTS_H +#define TRTLLM_NCCL_DEVICE_CONSTANTS_H + +#include + +namespace tensorrt_llm::kernels::nccl_device +{ + +// CUDA and kernel constants +constexpr int kWarpSize = 32; +constexpr int kMaxThreadsPerBlock = 1024; // Maximum block size configurable for performance. +constexpr int kMinThreadsPerBlock = kWarpSize; // Minimum block size is a warp. +constexpr int kMaxUnrollFactor = 8; // We require manual instantiation and switches. Changing the number is not good + // enough, see launcher function for details +} // namespace tensorrt_llm::kernels::nccl_device + +#endif // TRTLLM_NCCL_DEVICE_CONSTANTS_H diff --git a/cpp/tensorrt_llm/kernels/nccl_device/kernels.cuh b/cpp/tensorrt_llm/kernels/nccl_device/kernels.cuh new file mode 100644 index 00000000000..eebf63d3471 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/kernels.cuh @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRTLLM_NCCL_DEVICE_KERNELS_CUH +#define TRTLLM_NCCL_DEVICE_KERNELS_CUH + +#include "constants.h" +#include "multimem.h" +#include "nccl.h" +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) +#include "nccl_device.h" +#endif +#include "tensorrt_llm/common/assert.h" +#include "vector_types.h" +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::kernels::nccl_device +{ + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ + constexpr unsigned int kFinalMask = 0xffffffff; +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(kFinalMask, val[i], mask, kWarpSize); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +// AllReduce deterministic multimem unrolled kernel with template parameters +template +__global__ void fusedAllReduceRMSNormKernel(ncclWindow_t input_win, ncclWindow_t output_win, const TN* residual, + ncclWindow_t residual_out_win, const TN* weight, const TN* bias, int const startToken, int const hidden_size, + int const tokensPerRank, ncclDevComm devComm, float const eps) +{ + + using accType = typename VectorType::accType; + ncclLsaBarrierSession bar{ + ncclCoopCta(), devComm, ncclTeamLsa(devComm), devComm.lsaBarrier, blockIdx.x, true, devComm.lsaMultimem}; + bar.sync(ncclCoopCta(), cuda::memory_order_relaxed); + + // Calculate which token this block should process +#pragma unroll 1 + for (int token_offset = blockIdx.x; token_offset < tokensPerRank; token_offset += gridDim.x) + { + int const token_id = token_offset + startToken; + // Calculate elements per vector type + constexpr int elems_per_vec = sizeof(TN) / sizeof(T); + + int const token_base_offset = token_id * hidden_size; // Base offset for this token in T elements + + // Calculate warp and lane within this block + int const warp_id = threadIdx.x / kWarpSize; + int const lane_id = threadIdx.x % kWarpSize; + + // Ensure warp striding through memory within the token + // Scale offsets by elements per vector since each thread handles more data with vectors + int const warp_offset = (warp_id * kWarpSize * Nunroll) * elems_per_vec; + int const lane_offset = lane_id * elems_per_vec; + + int const base_offset_T = warp_offset + lane_offset + token_base_offset; + + // Get aligned pointers for vector types + TN* send_ptr = reinterpret_cast(ncclGetMultimemPointer(input_win, 0, devComm.lsaMultimem)); + TN* recv_ptr = reinterpret_cast( + oneShot ? ncclGetLocalPointer(output_win, 0) : ncclGetMultimemPointer(output_win, 0, devComm.lsaMultimem)); + + assert(send_ptr != nullptr); + assert(recv_ptr != nullptr); + TN* residual_out = nullptr; + if constexpr (useResidual) + { + residual_out = oneShot + ? reinterpret_cast(ncclGetLocalPointer(residual_out_win, 0)) + : reinterpret_cast(ncclGetMultimemPointer(residual_out_win, 0, devComm.lsaMultimem)); + + assert(residual != nullptr); + assert(residual_out != nullptr); + } + if constexpr (useBias) + { + assert(bias != nullptr); + } + + // Process exactly the elements assigned to this thread + TN v[Nunroll]; + accType local_sum_squares = accType{0}; // For RMS calculation +#pragma unroll Nunroll + for (int i = 0; i < Nunroll; i++) + { + int const stride_offset = i * kWarpSize * elems_per_vec; // Scale stride by elements per vector + size_t const offset_T = base_offset_T + stride_offset; + size_t const offset_TN = offset_T / elems_per_vec; // Convert to vector offset + + v[i] = multimemLoadSum(reinterpret_cast(send_ptr + offset_TN)); + } + +#pragma unroll Nunroll + for (int i = 0; i < Nunroll; i++) + { + int const stride_offset = i * kWarpSize * elems_per_vec; // Scale stride by elements per vector + size_t const offset_T = base_offset_T + stride_offset; + size_t const offset_TN = offset_T / elems_per_vec; // Convert to vector offset + // The residual is the allreduced result (v) plus the input residual + T const* residual_elem = useResidual ? reinterpret_cast(residual + offset_TN) : nullptr; + T* v_elem = reinterpret_cast(&v[i]); + +#pragma unroll elems_per_vec + for (int j = 0; j < elems_per_vec; ++j) + { + if constexpr (useResidual) + { + // Residual = allreduced_result + input_residual + v_elem[j] + = static_cast(static_cast(v_elem[j]) + static_cast(residual_elem[j])); + } + + // Calculate sum of squares using residual values + accType value = static_cast(v_elem[j]); + local_sum_squares += value * value; + } + } + +#pragma unroll Nunroll + for (int i = 0; i < Nunroll; i++) + { + int const stride_offset = i * kWarpSize * elems_per_vec; // Scale stride by elements per vector + size_t const offset_T = base_offset_T + stride_offset; + size_t const offset_TN = offset_T / elems_per_vec; // Convert to vector offset + if (!oneShot) + multimemStore(reinterpret_cast(residual_out + offset_TN), v[i]); + else + residual_out[offset_TN] = v[i]; + } + + // RMS normalization: each block processes exactly one token + __shared__ accType rms; + blockReduceSumV2(&local_sum_squares); + if (threadIdx.x == 0) + { + accType const block_sum_squares = local_sum_squares; + rms = rsqrtf((block_sum_squares / static_cast(hidden_size)) + eps); + } + // Synchronize again to ensure RMS is computed before using it + __syncthreads(); + + // Apply RMS normalization with per-token weight and bias +#pragma unroll Nunroll + for (int i = 0; i < Nunroll; i++) + { + // Get the position within the hidden dimension for this thread + // Since each block processes one token, we just need the position within that token + int const hidden_dim_pos = warp_offset + lane_offset + i * kWarpSize * elems_per_vec; + + // Index into weight and bias arrays: just the position within hidden dimension + TN weight_vec = weight[hidden_dim_pos / elems_per_vec]; + TN bias_vec = useBias ? bias[hidden_dim_pos / elems_per_vec] : TN{0}; + + // Apply RMS normalization: v = (v / rms) * weight + bias + // Unroll vector types and handle each element individually with proper type promotion + T* v_elem = reinterpret_cast(&v[i]); + T* weight_elem = reinterpret_cast(&weight_vec); + T* bias_elem = reinterpret_cast(&bias_vec); + +#pragma unroll elems_per_vec + for (int j = 0; j < elems_per_vec; ++j) + { + // Promote to accType for intermediate calculations + accType v_acc = static_cast(v_elem[j]); + accType weight_acc = static_cast(weight_elem[j]); + accType bias_acc = static_cast(bias_elem[j]); + + // Apply RMS normalization: v = (v / rms) * weight + bias + accType normalized = v_acc * rms; + accType weighted = normalized * weight_acc; + accType result = weighted + bias_acc; + + // Cast back to T + v_elem[j] = static_cast(result); + } + } +#pragma unroll Nunroll + for (int i = 0; i < Nunroll; i++) + { + int const stride_offset = i * kWarpSize * elems_per_vec; // Scale stride by elements per vector + size_t const offset_T = base_offset_T + stride_offset; + size_t const offset_TN = offset_T / elems_per_vec; // Convert to vector offset + if (!oneShot) + multimemStore(reinterpret_cast(recv_ptr + offset_TN), v[i]); + else + recv_ptr[offset_TN] = v[i]; + } + } + bar.sync(ncclCoopCta(), cuda::memory_order_release); +} + +#endif // NCCL_VERSION_CODE >= NCCL_VERSION(2,28,0) + +} // namespace tensorrt_llm::kernels::nccl_device + +#endif // TRTLLM_NCCL_DEVICE_KERNELS_CUH diff --git a/cpp/tensorrt_llm/kernels/nccl_device/multimem.h b/cpp/tensorrt_llm/kernels/nccl_device/multimem.h new file mode 100644 index 00000000000..d2e75f12310 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/multimem.h @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRTLLM_NCCL_DEVICE_MULTIMEM_H +#define TRTLLM_NCCL_DEVICE_MULTIMEM_H + +#include +#if CUDART_VERSION >= 11000 +#include +#endif +#include "constants.h" +#include "vector_types.h" +#include +#include +#include +#include + +namespace tensorrt_llm::kernels::nccl_device +{ + +// Architecture-specific feature detection based on PTX ISA documentation +// PTX ISA 8.1: Basic multimem support (sm_90+) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) +#define ARCH_HAS_MULTIMEM 1 +#else +#define ARCH_HAS_MULTIMEM 0 +#endif + +// PTX ISA 8.2: .acc::f32 qualifier support (sm_90+) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) +#define ARCH_HAS_MULTIMEM_ACC_F32 1 +#else +#define ARCH_HAS_MULTIMEM_ACC_F32 0 +#endif + +// PTX ISA 8.6: FP8 types and .acc::f16 qualifier support +// Supported on sm_100a, sm_101a (sm_110a), sm_120a, sm_121a +// And family-specific architectures sm_100f+, sm_101f+ (sm_110f+) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) +#define ARCH_HAS_MULTIMEM_FP8 1 +#define ARCH_HAS_MULTIMEM_ACC_F16 1 +#else +#define ARCH_HAS_MULTIMEM_FP8 0 +#define ARCH_HAS_MULTIMEM_ACC_F16 0 +#endif + +// Basic data type support (independent of multimem) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) +#define ARCH_HAS_BF16 1 +#else +#define ARCH_HAS_BF16 0 +#endif + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890) +#define ARCH_HAS_FP8 1 +#else +#define ARCH_HAS_FP8 0 +#endif + +// Base template for multimemLoadSum - device assert for SM < 90 +template +__device__ __forceinline__ valT multimemLoadSum(ptrT const* addr) +{ + assert(false && "multimemLoadSum requires SM90+ (Hopper) with multimem support. This operation cannot be emulated on older architectures."); + return valT{}; // Unreachable, but satisfies return type requirement +} + +// SM90+ specializations for supported types +#if ARCH_HAS_MULTIMEM +// Basic multimem support (PTX ISA 8.1) +template <> +__device__ __forceinline__ double multimemLoadSum(double const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + double result; + asm volatile("multimem.ld_reduce.global.add.f64 %0, [%1];" : "=d"(result) : "l"(multimem_addr) : "memory"); + return result; +} + +template <> +__device__ __forceinline__ float multimemLoadSum(float const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + float result; + asm volatile("multimem.ld_reduce.global.add.f32 %0, [%1];" : "=f"(result) : "l"(multimem_addr) : "memory"); + return result; +} + +template <> +__device__ __forceinline__ float2 multimemLoadSum(float const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + float2 result; + asm volatile("multimem.ld_reduce.global.add.v2.f32 {%0, %1}, [%2];" + : "=f"(result.x), "=f"(result.y) + : "l"(multimem_addr) + : "memory"); + return result; +} + +template <> +__device__ __forceinline__ float4 multimemLoadSum(float const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + float4 result; + asm volatile("multimem.ld_reduce.global.add.v4.f32 {%0, %1, %2, %3}, [%4];" + : "=f"(result.x), "=f"(result.y), "=f"(result.z), "=f"(result.w) + : "l"(multimem_addr) + : "memory"); + return result; +} + +#if ARCH_HAS_MULTIMEM_ACC_F32 +// .acc::f32 qualifier support (PTX ISA 8.2) +template <> +__device__ __forceinline__ HalfVector multimemLoadSum(half const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + HalfVector result; + asm volatile("multimem.ld_reduce.global.add.v4.f16x2.acc::f32 {%0, %1, %2, %3}, [%4];" + : "=r"(result.data.x), "=r"(result.data.y), "=r"(result.data.z), "=r"(result.data.w) + : "l"(multimem_addr) + : "memory"); + return result; +} + +template <> +__device__ __forceinline__ BFloat16Vector multimemLoadSum<__nv_bfloat16, BFloat16Vector>(__nv_bfloat16 const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + BFloat16Vector result; + asm volatile("multimem.ld_reduce.global.add.v4.bf16x2.acc::f32 {%0, %1, %2, %3}, [%4];" + : "=r"(result.data.x), "=r"(result.data.y), "=r"(result.data.z), "=r"(result.data.w) + : "l"(multimem_addr) + : "memory"); + return result; +} +#endif // ARCH_HAS_MULTIMEM_ACC_F32 + +#if ARCH_HAS_MULTIMEM_FP8 +// FP8 types and .acc::f16 qualifier support (PTX ISA 8.6) +template <> +__device__ __forceinline__ FP8E5M2x4Vector multimemLoadSum<__nv_fp8_e5m2, FP8E5M2x4Vector>(__nv_fp8_e5m2 const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + FP8E5M2x4Vector result; + asm volatile("multimem.ld_reduce.global.add.v4.e5m2x4.acc::f16 {%0, %1, %2, %3}, [%4];" + : "=r"(result.data.x), "=r"(result.data.y), "=r"(result.data.z), "=r"(result.data.w) + : "l"(multimem_addr) + : "memory"); + return result; +} + +template <> +__device__ __forceinline__ FP8E4M3x4Vector multimemLoadSum<__nv_fp8_e4m3, FP8E4M3x4Vector>(__nv_fp8_e4m3 const* addr) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + FP8E4M3x4Vector result; + asm volatile("multimem.ld_reduce.global.add.v4.e4m3x4.acc::f16 {%0, %1, %2, %3}, [%4];" + : "=r"(result.data.x), "=r"(result.data.y), "=r"(result.data.z), "=r"(result.data.w) + : "l"(multimem_addr) + : "memory"); + return result; +} +#endif // ARCH_HAS_MULTIMEM_FP8 +#endif // ARCH_HAS_MULTIMEM + +// Base template for multimemStore - device assert for SM < 90 +template +__device__ __forceinline__ void multimemStore(ptrT* addr, valT const val) +{ + assert(false && "multimemStore requires SM90+ (Hopper) with multimem support. This operation cannot be emulated on older architectures."); +} + +// SM90+ specializations for supported types +#if ARCH_HAS_MULTIMEM +// Basic multimem support (PTX ISA 8.1) +template <> +__device__ __forceinline__ void multimemStore(double* addr, double const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.f64 [%0], %1;" : : "l"(multimem_addr), "d"(val) : "memory"); +} + +template <> +__device__ __forceinline__ void multimemStore(float* addr, float const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.f32 [%0], %1;" : : "l"(multimem_addr), "f"(val) : "memory"); +} + +template <> +__device__ __forceinline__ void multimemStore(float* addr, float2 const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.v2.f32 [%0], {%1, %2};" : : "l"(multimem_addr), "f"(val.x), "f"(val.y) : "memory"); +} + +template <> +__device__ __forceinline__ void multimemStore(float* addr, float4 const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.v4.f32 [%0], {%1, %2, %3, %4};" + : + : "l"(multimem_addr), "f"(val.x), "f"(val.y), "f"(val.z), "f"(val.w) + : "memory"); +} + +template <> +__device__ __forceinline__ void multimemStore(half* addr, HalfVector const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.v4.f16x2 [%0], {%1,%2,%3,%4};" + : + : "l"(multimem_addr), "r"(val.data.x), "r"(val.data.y), "r"(val.data.z), "r"(val.data.w) + : "memory"); +} + +#if ARCH_HAS_MULTIMEM_ACC_F32 +template <> +__device__ __forceinline__ void multimemStore<__nv_bfloat16, BFloat16Vector>( + __nv_bfloat16* addr, const BFloat16Vector val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" + : + : "l"(multimem_addr), "r"(val.data.x), "r"(val.data.y), "r"(val.data.z), "r"(val.data.w) + : "memory"); +} +#endif // ARCH_HAS_MULTIMEM_ACC_F32 + +#if ARCH_HAS_MULTIMEM_FP8 +// FP8 types support (PTX ISA 8.6) +template <> +__device__ __forceinline__ void multimemStore<__nv_fp8_e5m2, FP8E5M2x4Vector>( + __nv_fp8_e5m2* addr, FP8E5M2x4Vector const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.v4.e5m2x4 [%0], {%1,%2,%3,%4};" + : + : "l"(multimem_addr), "r"(val.data.x), "r"(val.data.y), "r"(val.data.z), "r"(val.data.w) + : "memory"); +} + +template <> +__device__ __forceinline__ void multimemStore<__nv_fp8_e4m3, FP8E4M3x4Vector>( + __nv_fp8_e4m3* addr, FP8E4M3x4Vector const val) +{ + uintptr_t const multimem_addr = reinterpret_cast(addr); + asm volatile("multimem.st.global.v4.e4m3x4 [%0], {%1,%2,%3,%4};" + : + : "l"(multimem_addr), "r"(val.data.x), "r"(val.data.y), "r"(val.data.z), "r"(val.data.w) + : "memory"); +} +#endif // ARCH_HAS_MULTIMEM_FP8 +#endif // ARCH_HAS_MULTIMEM + +} // namespace tensorrt_llm::kernels::nccl_device + +#endif // TRTLLM_NCCL_DEVICE_MULTIMEM_H diff --git a/cpp/tensorrt_llm/kernels/nccl_device/vector_types.h b/cpp/tensorrt_llm/kernels/nccl_device/vector_types.h new file mode 100644 index 00000000000..ee45a65dbb5 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/nccl_device/vector_types.h @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::kernels::nccl_device +{ + +// Helper struct to distinguish half vector operations +struct HalfVector +{ + uint4 data; + using accType = float; // Accumulation type for precision + + __device__ __forceinline__ HalfVector(uint4 val) + : data(val) + { + } + + __device__ __forceinline__ HalfVector() + : data({0, 0, 0, 0}) + { + } + + __device__ __forceinline__ HalfVector(int) + : data({0, 0, 0, 0}) + { + } // For {0} initialization +}; + +// Helper struct to distinguish bfloat16 vector operations +struct BFloat16Vector +{ + uint4 data; + using accType = float; // Accumulation type for precision + + __device__ __forceinline__ BFloat16Vector(uint4 val) + : data(val) + { + } + + __device__ __forceinline__ BFloat16Vector() + : data({0, 0, 0, 0}) + { + } + + __device__ __forceinline__ BFloat16Vector(int) + : data({0, 0, 0, 0}) + { + } // For {0} initialization +}; + +// Helper struct to distinguish FP8 e5m2x4 vector operations +// e5m2x4 means 4 e5m2 elements packed into uint4 (32 bits total) +struct FP8E5M2x4Vector +{ + uint4 data; + using accType = float; // Accumulation type for precision + + __device__ __forceinline__ FP8E5M2x4Vector(uint4 val) + : data(val) + { + } + + __device__ __forceinline__ FP8E5M2x4Vector() + : data({0, 0, 0, 0}) + { + } + + __device__ __forceinline__ FP8E5M2x4Vector(int) + : data({0, 0, 0, 0}) + { + } // For {0} initialization +}; + +// Helper struct to distinguish FP8 e4m3x4 vector operations +// e4m3x4 means 4 e4m3 elements packed into uint4 (32 bits total) +struct FP8E4M3x4Vector +{ + uint4 data; + using accType = float; // Accumulation type for precision + + __device__ __forceinline__ FP8E4M3x4Vector(uint4 val) + : data(val) + { + } + + __device__ __forceinline__ FP8E4M3x4Vector() + : data({0, 0, 0, 0}) + { + } + + __device__ __forceinline__ FP8E4M3x4Vector(int) + : data({0, 0, 0, 0}) + { + } // For {0} initialization +}; + +// Vector type mapping +template +struct VectorType +{ + using type = T; // Default to scalar (elementsPerVector = 1) + using accType = T; +}; + +// Specializations for vectorized types +template <> +struct VectorType +{ + using type = float4; // Use float4 for best vectorization (elementsPerVector = 4) + using accType = float; +}; + +template <> +struct VectorType +{ + using type = double; // Use double for vectorization since that is the only multimem supported version + using accType = double; +}; + +template <> +struct VectorType +{ + using type = HalfVector; // Use HalfVector for proper half arithmetic + using accType = float; // Always use float for accumulation for numerical stability +}; + +template <> +struct VectorType<__nv_bfloat16> +{ + using type = BFloat16Vector; // Use BFloat16Vector for proper bfloat16 arithmetic + using accType = float; // Always use float for accumulation for numerical stability +}; + +template <> +struct VectorType<__nv_fp8_e5m2> +{ + using type = FP8E5M2x4Vector; // Use FP8E5M2x4Vector for FP8 e5m2x4 arithmetic + using accType = float; +}; + +template <> +struct VectorType<__nv_fp8_e4m3> +{ + using type = FP8E4M3x4Vector; // Use FP8E4M3x4Vector for FP8 e4m3x4 arithmetic + using accType = float; +}; + +} // namespace tensorrt_llm::kernels::nccl_device diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp index e0f2d5cce2e..4ec3f62c381 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "ub_allocator.h" +#include "nccl.h" #include "tensorrt_llm/common/opUtils.h" #include #include @@ -131,11 +132,79 @@ NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper() return *mNCCLHelper; } +NCCLUserBufferAllocator::~NCCLUserBufferAllocator() +{ + // Deallocate buffers + auto& ncclHelper = getNCCLHelper(); + if (ncclHelper.isLoaded()) + { + auto ncclMemFreeFunc = ncclHelper.getNCCLMemFree(); + auto ncclCommWindowDeregisterFunc = ncclHelper.getNCCLCommWindowDeregister(); + if (ncclCommWindowDeregisterFunc == nullptr) + { + TLLM_LOG_WARNING("NCCL buffer windows cannot be released."); + } + for (auto buffer : mBuffers) + { + buffer.size = 0; + if (ncclCommWindowDeregisterFunc != nullptr) + { + NCCLCHECK(ncclCommWindowDeregisterFunc(*mComm, buffer.window)); + } + if (ncclMemFreeFunc != nullptr) + { + NCCLCHECK(ncclMemFreeFunc(buffer.addr)); + } + } + auto ncclDevCommDestroyFunc = ncclHelper.getNCCLDevCommDestroy(); + if (ncclDevCommDestroyFunc) + { + for (auto x : mDevCommBlockID) + { + ncclDevComm devComm = x.second; + NCCLCHECK(ncclDevCommDestroyFunc(*mComm, devComm)); + } + } + } +} + +ncclDevComm NCCLUserBufferAllocator::getNCCLDevComm(int const numLsaBarriers) +{ + constexpr bool multimemSupport = true; + auto commIter = mDevCommBlockID.find(numLsaBarriers); // codespell:ignore word + if (commIter == mDevCommBlockID.end()) + { + ncclDevComm devComm; + ncclDevCommRequirements reqs = {0}; + memset(&reqs, 0, sizeof(ncclDevCommRequirements)); + reqs.lsaBarrierCount = numLsaBarriers; + reqs.lsaMultimem = multimemSupport; + + auto& ncclHelper = getNCCLHelper(); + auto ncclDevCommCreateFunc = ncclHelper.getNCCLDevCommCreate(); + + ncclResult_t ncclError = ncclDevCommCreateFunc(*mComm, &reqs, &devComm); + TLLM_CHECK_WITH_INFO( + ncclError == ncclSuccess, "Failed to create NCCL device communicator: %s", ncclGetErrorString(ncclError)); + mDevCommBlockID[numLsaBarriers] = devComm; + } + commIter = mDevCommBlockID.find(numLsaBarriers); + if (commIter == mDevCommBlockID.end()) + { + TLLM_THROW("NCCL cannot create required device communicator"); + } + return commIter->second; +} + // NCCLHelper implementation NCCLHelper::NCCLHelper() : mLibraryHandle(nullptr) , mNCCLCommWindowRegister(nullptr) , mNCCLMemAlloc(nullptr) + , mNCCLCommWindowDeregister(nullptr) + , mNCCLMemFree(nullptr) + , mNCCLDevCommCreate(nullptr) + , mNCCLDevCommDestroy(nullptr) , mIsLoaded(false) { loadNCCLLibrary(); @@ -166,12 +235,24 @@ void NCCLHelper::loadNCCLLibrary() for (int i = 0; libraryNames[i] != nullptr; ++i) { + TLLM_LOG_INFO("Attempting to load NCCL library: %s", libraryNames[i]); mLibraryHandle = loadLibraryHandle(libraryNames[i]); if (mLibraryHandle) { TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryNames[i]); + + // Get the actual path of the loaded library + Dl_info info; + if (dladdr(mLibraryHandle, &info) && info.dli_fname) + { + TLLM_LOG_INFO("NCCL library loaded from: %s", info.dli_fname); + } break; } + else + { + TLLM_LOG_WARNING("Failed to load NCCL library: %s", libraryNames[i]); + } } if (!mLibraryHandle) @@ -186,12 +267,40 @@ void NCCLHelper::loadNCCLLibrary() mNCCLMemAlloc = reinterpret_cast(getSymbolAddress(mLibraryHandle, "ncclMemAlloc")); - if (mNCCLCommWindowRegister == nullptr) + mNCCLCommWindowDeregister = reinterpret_cast( + getSymbolAddress(mLibraryHandle, "ncclCommWindowDeregister")); + + mNCCLMemFree = reinterpret_cast(getSymbolAddress(mLibraryHandle, "ncclMemFree")); + + // Try to resolve device communicator functions using proper symbol resolution + mNCCLDevCommCreate = resolveNCCLDevCommCreate(mLibraryHandle); + mNCCLDevCommDestroy = resolveNCCLDevCommDestroy(mLibraryHandle); + + if (mNCCLCommWindowRegister == nullptr or mNCCLCommWindowDeregister == nullptr) { TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported."); } - if (mNCCLMemAlloc) + if (mNCCLDevCommCreate == nullptr or mNCCLDevCommDestroy == nullptr) + { + TLLM_LOG_WARNING( + "Failed to load ncclDevCommCreate/ncclDevCommDestroy symbols, NCCL fused kernels will not be " + "supported. Ensure NCCL version >= 2.28."); + if (mNCCLDevCommCreate == nullptr) + { + TLLM_LOG_WARNING("ncclDevCommCreate symbol not found (tried both C and C++ mangled names)"); + } + if (mNCCLDevCommDestroy == nullptr) + { + TLLM_LOG_WARNING("ncclDevCommDestroy symbol not found (tried both C and C++ mangled names)"); + } + } + else + { + TLLM_LOG_INFO("Successfully loaded ncclDevCommCreate and ncclDevCommDestroy symbols"); + } + + if (mNCCLMemAlloc and mNCCLMemFree) { mIsLoaded = true; } @@ -229,6 +338,72 @@ void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName) #endif } +// Robust symbol resolution for device communicator functions +NCCLHelper::ncclDevCommCreateFunc NCCLHelper::resolveNCCLDevCommCreate(void* handle) +{ + if (!handle) + return nullptr; + + // Try C-style symbol first (preferred) + void* symbol = getSymbolAddress(handle, "ncclDevCommCreate"); + if (symbol) + { + TLLM_LOG_DEBUG("Found ncclDevCommCreate with C linkage"); + return reinterpret_cast(symbol); + } + + // Try common C++ mangled variants (fallback) + char const* mangledNames[] + = {"_Z17ncclDevCommCreateP8ncclCommPK23ncclDevCommRequirementsP11ncclDevComm", // GCC/Clang + "?ncclDevCommCreate@@YAHPAUncclComm@@PBUncclDevCommRequirements@@PAUncclDevComm@@@Z", // MSVC + nullptr}; + + for (int i = 0; mangledNames[i] != nullptr; ++i) + { + symbol = getSymbolAddress(handle, mangledNames[i]); + if (symbol) + { + TLLM_LOG_WARNING("Found ncclDevCommCreate with C++ mangled name (fragile): %s", mangledNames[i]); + return reinterpret_cast(symbol); + } + } + + TLLM_LOG_DEBUG("ncclDevCommCreate not found with any known symbol name"); + return nullptr; +} + +NCCLHelper::ncclDevCommDestroyFunc NCCLHelper::resolveNCCLDevCommDestroy(void* handle) +{ + if (!handle) + return nullptr; + + // Try C-style symbol first (preferred) + void* symbol = getSymbolAddress(handle, "ncclDevCommDestroy"); + if (symbol) + { + TLLM_LOG_DEBUG("Found ncclDevCommDestroy with C linkage"); + return reinterpret_cast(symbol); + } + + // Try common C++ mangled variants (fallback) + char const* mangledNames[] = {"_Z18ncclDevCommDestroyP8ncclCommPK11ncclDevComm", // GCC/Clang + "?ncclDevCommDestroy@@YAHPAUncclComm@@PBUncclDevComm@@@Z", // MSVC + nullptr}; + + for (int i = 0; mangledNames[i] != nullptr; ++i) + { + symbol = getSymbolAddress(handle, mangledNames[i]); + if (symbol) + { + TLLM_LOG_WARNING("Found ncclDevCommDestroy with C++ mangled name (fragile): %s", mangledNames[i]); + return reinterpret_cast(symbol); + } + } + + TLLM_LOG_DEBUG("ncclDevCommDestroy not found with any known symbol name"); + return nullptr; +} + NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister() { return mNCCLCommWindowRegister; @@ -239,6 +414,26 @@ NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc() return mNCCLMemAlloc; } +NCCLHelper::ncclCommWindowDeregisterFunc NCCLHelper::getNCCLCommWindowDeregister() +{ + return mNCCLCommWindowDeregister; +} + +NCCLHelper::ncclMemFreeFunc NCCLHelper::getNCCLMemFree() +{ + return mNCCLMemFree; +} + +NCCLHelper::ncclDevCommCreateFunc NCCLHelper::getNCCLDevCommCreate() +{ + return mNCCLDevCommCreate; +} + +NCCLHelper::ncclDevCommDestroyFunc NCCLHelper::getNCCLDevCommDestroy() +{ + return mNCCLDevCommDestroy; +} + bool NCCLHelper::isLoaded() const { return mIsLoaded; @@ -246,4 +441,26 @@ bool NCCLHelper::isLoaded() const bool UserBufferAllocator::use_nccl_symmetric = false; +std::shared_ptr +NCCLUserBufferAllocator::getCachedNCCLDeviceLaunchConfig(nvinfer1::DataType dataType, int const hidden_dim, + int const num_tokens, int const rank, int const nRanks, bool useResidual, bool useBias) +{ + // Create cache key + LaunchConfigKey key{dataType, hidden_dim, num_tokens, rank, nRanks, useResidual, useBias}; + + // Check if config already exists in cache + auto it = mLaunchConfigCache.find(key); + if (it != mLaunchConfigCache.end()) + { + return it->second; // Return cached config + } + + // Create new config and cache it + auto config = tensorrt_llm::kernels::nccl_device::makeLaunchConfig( + dataType, hidden_dim, num_tokens, rank, nRanks, useResidual, useBias); + + mLaunchConfigCache[key] = config; + return config; +} + }; // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h index 4cc91497054..468bc389d1c 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h @@ -14,8 +14,17 @@ * limitations under the License. */ #pragma once +#include "nccl.h" +#include "nccl_device.h" +#include "tensorrt_llm/kernels/nccl_device/config.h" #include "tensorrt_llm/runtime/worldConfig.h" +#include #include +#include + +// Forward declarations for NCCL device communicator types +struct ncclDevComm; +struct ncclDevCommRequirements; #if ENABLE_MULTI_DEVICE #include "nccl.h" #include "userbuffers.h" @@ -89,6 +98,11 @@ class NCCLHelper // Dynamic loading function type definition using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int); using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t); + using ncclCommWindowDeregisterFunc = ncclResult_t (*)(ncclComm_t, ncclWindow_t); + using ncclMemFreeFunc = ncclResult_t (*)(void(*)); + using ncclDevCommCreateFunc + = ncclResult_t (*)(ncclComm_t, struct ncclDevCommRequirements const(*), struct ncclDevComm(*)); + using ncclDevCommDestroyFunc = ncclResult_t (*)(ncclComm_t, ncclDevComm); // Get function pointer for ncclCommWindowRegister ncclCommWindowRegisterFunc getNCCLCommWindowRegister(); @@ -96,6 +110,11 @@ class NCCLHelper // Get function pointer for ncclMemAlloc ncclMemAllocFunc getNCCLMemAlloc(); + ncclCommWindowDeregisterFunc getNCCLCommWindowDeregister(); + ncclMemFreeFunc getNCCLMemFree(); + ncclDevCommCreateFunc getNCCLDevCommCreate(); + ncclDevCommDestroyFunc getNCCLDevCommDestroy(); + // Check if NCCL library is successfully loaded bool isLoaded() const; @@ -104,6 +123,10 @@ class NCCLHelper void* loadLibraryHandle(char const* libName); void* getSymbolAddress(void* handle, char const* symbolName); + // Robust symbol resolution methods + ncclDevCommCreateFunc resolveNCCLDevCommCreate(void* handle); + ncclDevCommDestroyFunc resolveNCCLDevCommDestroy(void* handle); + #ifdef _WIN32 HMODULE mLibraryHandle; #else @@ -112,6 +135,10 @@ class NCCLHelper ncclCommWindowRegisterFunc mNCCLCommWindowRegister; ncclMemAllocFunc mNCCLMemAlloc; + ncclCommWindowDeregisterFunc mNCCLCommWindowDeregister; + ncclMemFreeFunc mNCCLMemFree; + ncclDevCommCreateFunc mNCCLDevCommCreate; + ncclDevCommDestroyFunc mNCCLDevCommDestroy; bool mIsLoaded; }; @@ -124,9 +151,40 @@ class NCCLUserBufferAllocator : public UserBufferAllocator // Get shared NCCLHelper instance static NCCLHelper& getNCCLHelper(); + ~NCCLUserBufferAllocator(); + + ncclDevComm getNCCLDevComm(int const numLsaBarriers); + + // Cached NCCL device launch config functionality + std::shared_ptr getCachedNCCLDeviceLaunchConfig( + nvinfer1::DataType dataType, int const hidden_dim, int const num_tokens, int const rank, int const nRanks, + bool useResidual, bool useBias); + private: std::shared_ptr mComm; static std::unique_ptr mNCCLHelper; + std::map mDevCommBlockID; + + // Cache for fused allreduce launch configs + struct LaunchConfigKey + { + nvinfer1::DataType dataType; + int hidden_dim; + int num_tokens; + int rank; + int nRanks; + bool useResidual; + bool useBias; + + bool operator<(LaunchConfigKey const& other) const + { + return std::tie(dataType, hidden_dim, num_tokens, rank, nRanks, useResidual, useBias) + < std::tie(other.dataType, other.hidden_dim, other.num_tokens, other.rank, other.nRanks, + other.useResidual, other.useBias); + } + }; + + std::map> mLaunchConfigCache; }; #else diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp index 4241cf8d859..847e3c4f561 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp @@ -137,7 +137,8 @@ bool AllreducePlugin::supportsFormatCombination( int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { int base_inputs = 0; - if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::UB) + if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::UB + || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC || mStrategy == AllReduceStrategyType::NCCL_DEVICE) { base_inputs = 1; } @@ -169,7 +170,9 @@ bool AllreducePlugin::supportsFormatCombination( TLLM_CHECK(nbInputs == (base_inputs + fusion_op_extra_inputs)); - if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB && pos == 1) + if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB + && mStrategy != AllReduceStrategyType::NCCL_SYMMETRIC && mStrategy != AllReduceStrategyType::NCCL_DEVICE + && pos == 1) { return (inOut[pos].type == nvinfer1::DataType::kINT64) && (inOut[pos].format == TensorFormat::kLINEAR); } @@ -341,6 +344,14 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe { runtimeStrategy = AllReduceStrategyType::UB; } + else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) + { + runtimeStrategy = AllReduceStrategyType::NCCL_SYMMETRIC; + } + else if (mStrategy == AllReduceStrategyType::NCCL_DEVICE) + { + runtimeStrategy = AllReduceStrategyType::NCCL_DEVICE; + } else { runtimeStrategy = selectImplementation(size, mGroup.size(), mType); @@ -370,6 +381,16 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank); break; } + case AllReduceStrategyType::NCCL_SYMMETRIC: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank); + break; + } + case AllReduceStrategyType::NCCL_DEVICE: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_DEVICE", rank); + break; + } default: break; } @@ -798,10 +819,17 @@ int AllreducePlugin::initialize() noexcept TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); mNcclComm = getComm(mGroup); - if (mStrategy != AllReduceStrategyType::NCCL) + // Skip topology detection for strategies that manage connectivity internally + if (!shouldSkipTopologyDetection(mStrategy)) { initGroupTopology(); } + else + { + // For strategies that skip topology detection, assume connectivity is supported + mIsP2PSupported = true; + mIsNVLINKSupported = true; + } TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); return 0; diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 469aafe6476..e8ef4862793 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -482,7 +482,8 @@ void initBindings(pybind11::module_& m) .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB) .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT) .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT) - .value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC); + .value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC) + .value("NCCL_DEVICE", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_DEVICE); // Initialize MoeLoadBalancer bindings initMoeBindings(m); diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index b1d5aee28ac..3d927c6a4f8 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -25,6 +25,7 @@ #include "tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h" #include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h" #include "tensorrt_llm/kernels/customAllReduceKernels.h" +#include "tensorrt_llm/kernels/nccl_device/config.h" #include "tensorrt_llm/kernels/quantization.h" #include "tensorrt_llm/kernels/userbuffers/ub_interface.h" #include "tensorrt_llm/runtime/mcastDeviceMemory.h" @@ -291,6 +292,8 @@ class AllreduceOp case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias); case AllReduceStrategyType::NCCL_SYMMETRIC: return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias); + case AllReduceStrategyType::NCCL_DEVICE: + return runNCCLAllReduceDeviceFusion(input, residual, norm_weight, scale, bias); case AllReduceStrategyType::MIN_LATENCY: case AllReduceStrategyType::ONESHOT: case AllReduceStrategyType::TWOSHOT: @@ -309,11 +312,17 @@ class AllreduceOp { mNcclComm = getComm(mGroup); } - if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB) + // Skip topology detection for strategies that manage connectivity internally + if (!shouldSkipTopologyDetection(mStrategy)) { - initGroupTopology(); } + else + { + // For strategies that skip topology detection, assume connectivity is supported + mIsP2PSupported = true; + mIsNVLINKSupported = true; + } TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, getRank()); return 0; @@ -335,6 +344,14 @@ class AllreduceOp TLLM_CHECK_WITH_INFO(tensorrt_llm::runtime::ub::ub_is_initialized(), "UserBuffer has not been initialized!"); auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance(); auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr()); + if (ub_buffer0.invalid()) + { + auto [symmetric_input, symmetric_ub_buffer0] + = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(), + cudaMemcpyDeviceToDevice, stream); + ub_buffer0 = symmetric_ub_buffer0; + } TLLM_CHECK(!ub_buffer0.invalid()); auto ub_comm = ub_manager.comm(); @@ -435,13 +452,12 @@ class AllreduceOp std::vector runNCCLAllReduceSymmetric(torch::Tensor const& input, torch::optional const& residual, torch::optional const& norm_weight, - torch::optional const& scale, torch::optional const& bias) + torch::optional const& scale, torch::optional const& bias) noexcept { auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); int size = input.numel(); auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance(); - auto ub_tensor0 = input; auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr()); if (ub_buffer0.invalid()) { @@ -450,23 +466,14 @@ class AllreduceOp cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(), cudaMemcpyDeviceToDevice, stream); ub_buffer0 = symmetric_ub_buffer0; - ub_tensor0 = symmetric_input; } TLLM_CHECK(!ub_buffer0.invalid()); auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); - std::visit(overloaded{[&, norm_out_ = norm_out](std::shared_ptr& rawComm) - { - NCCLCHECK_THROW(ncclAllReduce(ub_buffer0.addr, norm_out_.mutable_data_ptr(), size, - (*getDtypeMap())[mType], ncclSum, *rawComm, stream)); - }, - [&, norm_out_ = norm_out](c10::intrusive_ptr& torchPg) - { - PGCHECK_THROW(PgHelper{torchPg}.allreduce(ub_tensor0, {c10d::ReduceOp::SUM})); - std::ignore = norm_out_.copy_(ub_tensor0, true); - }}, - mNcclComm); + auto& rawComm = std::get>(mNcclComm); + NCCLCHECK(ncclAllReduce( + ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *rawComm, stream)); if (mOp == AllReduceFusionOp::NONE) { @@ -477,6 +484,115 @@ class AllreduceOp return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out); } + std::vector runNCCLAllReduceDeviceFusion(torch::Tensor const& input, + torch::optional const& residual, torch::optional const& norm_weight, + torch::optional const& scale, torch::optional const& bias) + { + auto const myRank = getRank(); + TLLM_LOG_DEBUG("runNCCLAllReduceDeviceFusion: rank=%d, fusion_op=%s", myRank, + tensorrt_llm::kernels::toString(mOp).c_str()); + + TLLM_CHECK_WITH_INFO(tensorrt_llm::runtime::ub::ub_is_initialized(), + "UserBuffer has not been initialized (required for NCCL_DEVICE)"); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + int size = input.numel(); + auto ub_tensor0 = input; + auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance(); + auto& allocator = tensorrt_llm::runtime::ub::UserBufferAllocator::Instance(); + auto* nccl_ub_allocator_ptr = dynamic_cast(&allocator); + TLLM_CHECK_WITH_INFO(nccl_ub_allocator_ptr != nullptr, + "NCCL_DEVICE requires the UBAllocator to be set up with the use_nccl option = True."); + auto& nccl_ub_allocator = *nccl_ub_allocator_ptr; + + auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr()); + + if (ub_buffer0.invalid()) + { + auto [symmetric_input, symmetric_ub_buffer0] + = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(), + cudaMemcpyDeviceToDevice, stream); + ub_buffer0 = symmetric_ub_buffer0; + ub_tensor0 = symmetric_input; + } + TLLM_CHECK(!ub_buffer0.invalid()); + + auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + TLLM_CHECK(!ub_buffer1.invalid()); + // Get communicator size + auto& rawComm = std::get>(mNcclComm); + int nRanks; + ncclResult_t ncclError = ncclCommCount(*rawComm, &nRanks); + TLLM_CHECK_WITH_INFO( + ncclError == ncclSuccess, "Failed to get NCCL communicator size: %s", ncclGetErrorString(ncclError)); + + switch (mOp) + { + case AllReduceFusionOp::NONE: + NCCLCHECK(ncclAllReduce(ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], + ncclSum, *rawComm, stream)); + + return {norm_out}; + case AllReduceFusionOp::RESIDUAL_RMS_NORM: + { + TORCH_CHECK(norm_weight, "norm_weight is required for residual rms norm allreduce"); + TORCH_CHECK(residual, "residual is required for residual rms norm allreduce"); + TORCH_CHECK(!bias, "bias is not supported for residual rms norm allreduce"); + + int const hidden_size = input.size(-1); + int const num_tokens = size / hidden_size; + + TLLM_LOG_DEBUG("NCCL_DEVICE RESIDUAL_RMS_NORM: rank=%d, hidden_size=%d, num_tokens=%d, nRanks=%d, dtype=%d", + myRank, hidden_size, num_tokens, nRanks, static_cast(mType)); + + std::shared_ptr launchConfig + = nccl_ub_allocator.getCachedNCCLDeviceLaunchConfig( + mType, hidden_size, num_tokens, myRank, nRanks, true, false); + + // Check if multimem is supported for this data type + bool multimemSupported = launchConfig->supportsMultimem(); + TLLM_LOG_DEBUG("NCCL_DEVICE: rank=%d, supportsMultimem=%s", myRank, multimemSupported ? "true" : "false"); + + if (multimemSupported) + { + ncclWindow_t inWindow = ub_buffer0.window; + ncclWindow_t outWindow = ub_buffer1.window; + TLLM_CHECK(inWindow != nullptr); + TLLM_CHECK(outWindow != nullptr); + + auto [residual_out, ub_buffer2] + = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + TLLM_CHECK(!ub_buffer2.invalid()); + ncclDevComm devComm = nccl_ub_allocator.getNCCLDevComm(launchConfig->getNumSMs()); + + launchConfig->launchRMSNorm(inWindow, outWindow, residual.value().data_ptr(), ub_buffer2.window, + norm_weight.value().data_ptr(), nullptr, devComm, mEps, stream); + + TLLM_LOG_DEBUG("NCCL_DEVICE: rank=%d, fused kernel launched successfully", myRank); + return {norm_out, residual_out}; + } + // Fall back to old strategy with warning + if (myRank == 0) + { + TLLM_LOG_WARNING( + "[RANK 0] NCCL device Fused AR not supported for data type %d, hidden size %d & %d nRanks. " + "Check DEBUG logs from supportsMultimem() for reason. Falling back to standard allreduce + " + "separate RMSNorm.", + static_cast(mType), hidden_size, nRanks); + } + TLLM_LOG_WARNING( + "NCCL device Fused AR not supported for data type %d, hidden size %d & %d nRanks on current " + "architecture. Falling back to standard allreduce + separate RMSNorm.", + static_cast(mType), hidden_size, nRanks); + } + // Intentional fallthrough to default + default: + NCCLCHECK(ncclAllReduce( + ub_buffer0.addr, ub_buffer1.addr, size, (*getDtypeMap())[mType], ncclSum, *rawComm, stream)); + return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out); + } + } + std::vector runLowPrecisionAllReduce(torch::Tensor const& input, torch::optional const& residual, torch::optional const& norm_weight, torch::optional const& scale, torch::optional const& bias) noexcept @@ -797,6 +913,79 @@ class AllreduceOp return {}; } + AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size) + { + AllReduceStrategyType runtime_strategy; + if (mStrategy == AllReduceStrategyType::UB) + { + runtime_strategy = AllReduceStrategyType::UB; + } + else if (mStrategy == AllReduceStrategyType::NCCL) + { + runtime_strategy = AllReduceStrategyType::NCCL; + } + else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) + { + runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC; + } + else if (mStrategy == AllReduceStrategyType::NCCL_DEVICE) + { + runtime_strategy = AllReduceStrategyType::NCCL_DEVICE; + } + else + { + // This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set. + static char* ifForBenchMark = std::getenv("OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY"); + if (ifForBenchMark != nullptr) + { + runtime_strategy = mStrategy; + } + else + { + runtime_strategy = selectImplementation(seq_len, size); + } + } + return runtime_strategy; + } + + void logRunTimeStrategy(AllReduceStrategyType strategy, int rank) + { + switch (strategy) + { + case AllReduceStrategyType::NCCL: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank); + break; + } + case AllReduceStrategyType::NCCL_SYMMETRIC: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank); + break; + } + case AllReduceStrategyType::NCCL_DEVICE: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_DEVICE", rank); + break; + } + case AllReduceStrategyType::MIN_LATENCY: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank); + break; + } + case AllReduceStrategyType::UB: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank); + break; + } + case AllReduceStrategyType::LOWPRECISION: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank); + break; + } + default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break; + } + } + void initGroupTopology() { static std::map, std::tuple> cache; @@ -928,9 +1117,11 @@ class AllreduceOp { if (mStrategy != AllReduceStrategyType::AUTO) { - // For UB,NCCL,NCCL_SYMMETRIC, the correctness of the strategy dispatching is guaranteed by the user. + // For UB,NCCL,NCCL_SYMMETRIC,NCCL_DEVICE, the correctness of the strategy dispatching is guaranteed by the + // user. if (mStrategy == AllReduceStrategyType::UB || mStrategy == AllReduceStrategyType::NCCL - || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) + || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC + || mStrategy == AllReduceStrategyType::NCCL_DEVICE) { return mStrategy; } @@ -974,6 +1165,12 @@ class AllreduceOp // If messageSize is less than maxWorkspaceSize, use NCCL, regardless of the fusion type. if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported) { + auto const rank = getRank(); + if (rank == 0) + { + TLLM_LOG_INFO("[RANK 0] Fallback to NCCL: msg_size_bytes=%zu, max_workspace=%zu, P2P=%d, NVLINK=%d", + message_size_bytes, max_workspace_size, mIsP2PSupported, mIsNVLINKSupported); + } return true; } diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 9278409aee0..f3103c186ec 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -193,7 +193,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"): "TWOSHOT": AllReduceStrategy.TWOSHOT, "LOWPRECISION": AllReduceStrategy.LOWPRECISION, "MNNVL": AllReduceStrategy.MNNVL, - "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC + "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC, + "NCCL_DEVICE": AllReduceStrategy.NCCL_DEVICE, } key = strategy.upper() return maps[key] if key in maps else AllReduceStrategy.AUTO diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e6da9fc216a..959137fdb7f 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -262,9 +262,11 @@ def __init__( ] try: - use_ub_for_nccl = ( - self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC" - and self._init_userbuffers(self.model.config.hidden_size)) + + use_ub_for_nccl = (self.llm_args.allreduce_strategy + in ("NCCL_SYMMETRIC", "NCCL_DEVICE") + and self._init_userbuffers( + self.model.config.hidden_size)) if self._torch_compile_enabled: set_torch_compiling(True) use_ub = not use_ub_for_nccl and ( @@ -2477,7 +2479,8 @@ def _init_userbuffers(self, hidden_size): # Disable UB for unsupported platforms if not ub.ub_supported(): return False - use_nccl_symmetric = self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC" + use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy in ( + "NCCL_SYMMETRIC", "NCCL_DEVICE") ub.initialize_userbuffers_manager( self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size, self.mapping.rank, self.mapping.gpus_per_node, diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 282febd262e..e372dca1ead 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum): LOWPRECISION = 6 MNNVL = 7 NCCL_SYMMETRIC = 8 + NCCL_DEVICE = 9 class AllReduceFusionOp(IntEnum): diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index e002c42ddd1..0d2fa09346a 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2537,10 +2537,10 @@ class TorchLlmArgs(BaseLlmArgs): allreduce_strategy: Optional[Literal[ 'AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', - 'LOWPRECISION', 'MNNVL', - 'NCCL_SYMMETRIC']] = Field(default='AUTO', - description="Allreduce strategy to use.", - status="beta") + 'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC', + 'NCCL_DEVICE']] = Field(default='AUTO', + description="Allreduce strategy to use.", + status="beta") checkpoint_loader: Optional[object] = Field( default=None, description= diff --git a/tests/microbenchmarks/all_reduce.py b/tests/microbenchmarks/all_reduce.py index 837b0348129..a86651e0ab4 100644 --- a/tests/microbenchmarks/all_reduce.py +++ b/tests/microbenchmarks/all_reduce.py @@ -26,6 +26,7 @@ from cuda import cudart import tensorrt_llm as tllm +import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm import Mapping from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp from tensorrt_llm._torch.modules.rms_norm import RMSNorm @@ -33,7 +34,6 @@ nvtx_range) from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy -from tensorrt_llm.plugin.plugin import CustomAllReduceHelper def profile_allreduce( @@ -49,8 +49,6 @@ def profile_allreduce( scale=None, bias=None, ): - tllm.logger.set_level('error') - allreduce_params = AllReduceParams( fusion_op=fusion, residual=residual, @@ -122,40 +120,85 @@ def allreduce_benchmark( enable_cudagraph: bool = False, explore_2d: bool = False, save_csv: str = None, - enable_auto: bool = False, + strategy: str = None, + inner_loop: int = 200, + outer_loop: int = 10, + tokens_range: str = "1,16384,2", + hidden_sizes_range: str = "128,8192,2", ): - tllm.logger.set_level('error') + """ + Benchmark AllReduce operations. + + Args: + dtype: Data type for benchmarking + test_range: Range specification (min,max,ratio) + enable_cudagraph: Enable CUDA graph capture + explore_2d: Explore 2D parameter space (num_tokens x hidden_size) + save_csv: Path to save CSV results + strategy: Specific strategy to test (if None, tests default set: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL) + inner_loop: Number of iterations per timing measurement (default: 200) + outer_loop: Number of timing measurements to take (default: 10) + tokens_range: Range for number of tokens in 2D mode (min,max,ratio) (default: "1,16384,2") + hidden_sizes_range: Range for hidden sizes in 2D mode (min,max,ratio) (default: "128,8192,2") + """ world_size = tllm.mpi_world_size() rank = tllm.mpi_rank() local_rank = local_mpi_rank() gpus_per_node = local_mpi_size() + if world_size == 1: + if rank == 0: + print("ERROR: Benchmark must run with mpi_world_size > 1", + file=sys.stderr, + flush=True) + sys.exit(1) + + # Device setup torch.cuda.set_device(local_rank) cudart.cudaSetDevice(local_rank) mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size) sm_version = get_sm_version() - if world_size == 1: - raise RuntimeError("Benchmark must run with mpi_world_size > 1") - + # Data type setup torch_dtype = tllm._utils.str_dtype_to_torch(dtype) + dtype_size_bytes = torch_dtype.itemsize - inner_loop = 200 - outer_loop = 10 + # Parse test range + min_size, max_size, ratio = [int(i) for i in test_range.split(",")] # generate shape list shape_list = [] if explore_2d: - num_seqs_list = [ - 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + # Parse tokens range + min_tokens, max_tokens, tokens_ratio = [ + int(i) for i in tokens_range.split(",") + ] + + # Parse hidden sizes range + min_hidden, max_hidden, hidden_ratio = [ + int(i) for i in hidden_sizes_range.split(",") ] - hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192] + + # Generate token counts list + num_seqs_list = [] + current = min_tokens + while current <= max_tokens: + num_seqs_list.append(current) + current *= tokens_ratio + + # Generate hidden sizes list + hidden_size_list = [] + current = min_hidden + while current <= max_hidden: + hidden_size_list.append(current) + current *= hidden_ratio + + # Create all combinations for num_tokens, hidden_size in product(num_seqs_list, hidden_size_list): shape_list.append((num_tokens, hidden_size)) else: - min_size, max_size, ratio = [int(i) for i in test_range.split(",")] size = min_size hidden_size = min_size num_tokens = 1 @@ -174,21 +217,80 @@ def allreduce_benchmark( AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8, AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, ] - strategies = [ - AllReduceStrategy.NCCL, - AllReduceStrategy.ONESHOT, - AllReduceStrategy.TWOSHOT, - AllReduceStrategy.AUTO, - ] + + # Map strategy names to enum values + strategy_map = { + "NCCL": AllReduceStrategy.NCCL, + "MIN_LATENCY": AllReduceStrategy.MIN_LATENCY, + "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC, + "NCCL_DEVICE": AllReduceStrategy.NCCL_DEVICE, + "MNNVL": AllReduceStrategy.MNNVL, + "UB": AllReduceStrategy.UB, + "ONESHOT": AllReduceStrategy.ONESHOT, + "TWOSHOT": AllReduceStrategy.TWOSHOT, + "AUTO": AllReduceStrategy.AUTO, + } + + # Select strategies based on input + if strategy: + # Single strategy specified + if strategy.upper() not in strategy_map: + raise ValueError( + f"Unknown strategy: {strategy}. Available: {', '.join(strategy_map.keys())}" + ) + strategies = [strategy_map[strategy.upper()]] + else: + # Default: test main strategies + strategies = [ + AllReduceStrategy.NCCL, + AllReduceStrategy.NCCL_SYMMETRIC, + AllReduceStrategy.NCCL_DEVICE, + AllReduceStrategy.MNNVL, + ] + + # Validate strategy compatibility for user buffer initialization + # NCCL_SYMMETRIC and NCCL_DEVICE need UB with use_multicast=True + # UB strategy needs UB with use_multicast=False + # These two groups cannot be mixed in a single run + ub_multicast_strategies = { + AllReduceStrategy.NCCL_SYMMETRIC, AllReduceStrategy.NCCL_DEVICE + } + ub_no_multicast_strategies = {AllReduceStrategy.UB} + + has_multicast_strategies = any(s in ub_multicast_strategies + for s in strategies) + has_no_multicast_strategies = any(s in ub_no_multicast_strategies + for s in strategies) + + # Error out if incompatible strategies are mixed + if has_multicast_strategies and has_no_multicast_strategies: + multicast_strats = [ + s.name for s in strategies if s in ub_multicast_strategies + ] + no_multicast_strats = [ + s.name for s in strategies if s in ub_no_multicast_strategies + ] + raise ValueError( + f"Incompatible strategies selected: {multicast_strats} require use_multicast=True " + f"while {no_multicast_strats} require use_multicast=False. " + f"Please run these strategies separately using --strategy.") + + # Initialize user buffers if any strategy needs it + needs_ub = has_multicast_strategies or has_no_multicast_strategies + + if needs_ub: + max_bytes = max_size * dtype_size_bytes + use_multicast = has_multicast_strategies # True for NCCL_SYMMETRIC/NCCL_DEVICE, False for UB + + ub.initialize_userbuffers_manager(world_size, 1, 1, rank, + torch.cuda.device_count(), max_bytes, + use_multicast) + df = pd.DataFrame() for (num_tokens, hidden_size) in shape_list: message_size = num_tokens * hidden_size * torch.finfo( torch_dtype).bits // 8 - if message_size > CustomAllReduceHelper.max_workspace_size_auto( - mapping.tp_size): - continue - input = torch.ones((num_tokens, hidden_size), dtype=torch_dtype, device="cuda") @@ -208,7 +310,8 @@ def allreduce_benchmark( if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 and sm_version < 100: continue - if not enable_auto and strategy == AllReduceStrategy.AUTO: + # UB strategy doesn't support NONE fusion + if strategy == AllReduceStrategy.UB and fusion == AllReduceFusionOp.NONE: continue median_ms = profile_allreduce( @@ -242,6 +345,9 @@ def allreduce_benchmark( # print the dataframe if mapping.rank == 0: pd.set_option('display.max_rows', None) + pd.set_option('display.max_columns', None) + pd.set_option('display.width', None) + pd.set_option('display.max_colwidth', None) print(df) # # save the dataframe to a csv file @@ -253,24 +359,71 @@ def allreduce_benchmark( if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--dtype", "-t", default="bfloat16") + parser.add_argument("--dtype", + "-t", + default="bfloat16", + help="Data type for benchmarking") parser.add_argument( "--range", "-r", - default="256,256000000,10", # 256 to 256M + default="256,256000000,4", # 256 to 256M help="min_size,max_size,multiplicative_ratio") - parser.add_argument("--explore_2d", action="store_true", default=False) - parser.add_argument("--enable_cudagraph", action="store_true") - parser.add_argument("--save_csv", type=str, default=None) - parser.add_argument("--enable_auto", action="store_true", default=False) + parser.add_argument( + "--explore_2d", + action="store_true", + default=False, + help="Explore 2D parameter space (num_tokens x hidden_size)") + parser.add_argument("--enable_cudagraph", + action="store_true", + help="Enable CUDA graph capture") + parser.add_argument("--save_csv", + type=str, + default=None, + help="Path to save CSV results") + parser.add_argument( + "--strategy", + type=str, + default=None, + help= + "Test specific strategy. If not specified, defaults to: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL. " + "Available: NCCL, NCCL_SYMMETRIC, NCCL_DEVICE, MNNVL, MIN_LATENCY, UB, ONESHOT, TWOSHOT, AUTO" + ) + parser.add_argument( + "--inner_loop", + type=int, + default=200, + help="Number of iterations per timing measurement (default: 200)") + parser.add_argument( + "--outer_loop", + type=int, + default=10, + help="Number of timing measurements to take (default: 10)") + parser.add_argument( + "--tokens_range", + type=str, + default="1,16384,2", + help= + "Range for number of tokens in 2D mode: min,max,ratio (default: 1,16384,2)" + ) + parser.add_argument( + "--hidden_sizes_range", + type=str, + default="128,8192,2", + help= + "Range for hidden sizes in 2D mode: min,max,ratio (default: 128,8192,2)" + ) args = parser.parse_args() allreduce_benchmark( - args.dtype, - args.range, - args.enable_cudagraph, - args.explore_2d, - args.save_csv, - args.enable_auto, + dtype=args.dtype, + test_range=args.range, + enable_cudagraph=args.enable_cudagraph, + explore_2d=args.explore_2d, + save_csv=args.save_csv, + strategy=args.strategy, + inner_loop=args.inner_loop, + outer_loop=args.outer_loop, + tokens_range=args.tokens_range, + hidden_sizes_range=args.hidden_sizes_range, ) diff --git a/tests/unittest/_torch/multi_gpu/test_nccl_device.py b/tests/unittest/_torch/multi_gpu/test_nccl_device.py new file mode 100644 index 00000000000..27f6aafe5e8 --- /dev/null +++ b/tests/unittest/_torch/multi_gpu/test_nccl_device.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pickle +import sys +import traceback + +import cloudpickle +import pytest +import torch +from mpi4py import MPI + +import tensorrt_llm +import tensorrt_llm.bindings.internal.userbuffers as ub +from tensorrt_llm._torch.distributed import ( + AllReduce, + AllReduceFusionOp, + AllReduceParams, + AllReduceStrategy, +) +from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE +from tensorrt_llm.mapping import Mapping + +cloudpickle.register_pickle_by_value(sys.modules[__name__]) +MPI.pickle.__init__( + cloudpickle.dumps, + cloudpickle.loads, + pickle.HIGHEST_PROTOCOL, +) + +# needed since we reuse the mpi executor pool, first test running will leak a thread +pytestmark = pytest.mark.threadleak(enabled=False) + + +def init_userbuffers_allocator(tp_size, rank, max_ub_size): + ub.initialize_userbuffers_manager( + tp_size, 1, 1, rank, torch.cuda.device_count(), max_ub_size, True + ) + + +def create_userbuffers_tensor(shape, dtype): + # WAR pickle error + def func(shape, dtype): + return torch.ops.trtllm.create_userbuffers_tensor(shape, dtype) + + return func(shape, dtype) + + +# This rms_norm aligns with ub impl that calculate gamma * hidden in high +# precision +def rms_norm(input, gamma, eps): + variance = input.pow(2).mean(-1, keepdim=True) + hidden_states = input * torch.rsqrt(variance + eps) + return gamma.to(torch.float32) * hidden_states + + +def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma): + rank = tensorrt_llm.mpi_rank() + + # Set CUDA device BEFORE any CUDA operations or NCCL initialization + torch.cuda.set_device(rank) + + # Ensure CUDA context is properly initialized for this device + torch.cuda.synchronize() + + try: + support = ub.ub_supported() + if not support: + return True + eps = 1e-6 + + # Split tensors for tensor parallelism - ensure equal sizes + k_chunk_size = a.size(1) // tensor_parallel_size + + # Ensure we get exactly tensor_parallel_size chunks + a_partial = torch.split(a, k_chunk_size, dim=1) + b_partial = torch.split(b, k_chunk_size, dim=0) + + a_local = a_partial[rank].cuda() + b_local = b_partial[rank].cuda() + c = c.cuda() + gamma = gamma.cuda() + + ub_size = c.nelement() * c.element_size() + init_userbuffers_allocator(tensor_parallel_size, rank, ub_size) + + ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype) + hidden = torch.matmul(a_local, b_local, out=ub0_tensor) + + # Add barrier to ensure all MPI processes are ready before NCCL initialization + if ENABLE_MULTI_DEVICE: + tensorrt_llm._utils.mpi_barrier() + + # Ensure all ranks have set their CUDA devices before creating AllReduce + if ENABLE_MULTI_DEVICE: + tensorrt_llm._utils.mpi_barrier() + + mapping = Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=rank, + ) + ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.NCCL_DEVICE) + ar_params = AllReduceParams( + strategy=AllReduceStrategy.NCCL_DEVICE, + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=c, + norm_weight=gamma, + eps=eps, + ) + + res_ub, residual = ar.forward(hidden, all_reduce_params=ar_params) + res = res_ub.clone() + + torch.cuda.synchronize() + # Fully simulate matmul + allreduce behavior + ax = [a_partial[i].cuda() for i in range(0, tensor_parallel_size)] + bx = [b_partial[i].cuda() for i in range(0, tensor_parallel_size)] + h1 = [torch.matmul(ax[i], bx[i]) for i in range(0, tensor_parallel_size)] + sum = h1[0] + for i in range(1, tensor_parallel_size): + sum = sum + h1[i] + ref_residual = sum + c + ref = rms_norm(ref_residual.to(torch.float32), gamma, eps).to(res.dtype) + torch.testing.assert_close(ref, res, atol=5e-1, rtol=1e-2) + + # Current production version performs full scatter also for the + # residual, so we can compare unchunked + chunked_residual_comparison = False + + # Since we do not always perform an AllGather of the residual, + # let's compare on every rank the right portions of the residual + residual_chunk_size = ref_residual.size(0) // tensor_parallel_size + if ref_residual.size(0) % tensor_parallel_size != 0: + residual_chunk_size += 1 + chunk_start = rank * residual_chunk_size + chunk_end = min((rank + 1) * residual_chunk_size, ref_residual.size(0)) + + # If we do perform the AllGather implicitly we can compare the entire tensor. + if not chunked_residual_comparison: + chunk_start = 0 + chunk_end = ref_residual.size(0) + ref_residual = ref_residual[chunk_start:chunk_end] + residual = residual[chunk_start:chunk_end] + + torch.testing.assert_close(ref_residual, residual, atol=5e-1, rtol=1e-2) + + except Exception: + traceback.print_exc() + raise + return True + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test") +@pytest.mark.parametrize( + "mnk", [(128, 8192, 64), (79, 512, 32)], ids=lambda x: f"m{x[0]}_n{x[1]}_k{x[2]}" +) +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_user_buffers_ar_rms_norm(mnk, mpi_pool_executor): + # Ensure all MPI processes are synchronized before any test execution + if ENABLE_MULTI_DEVICE: + tensorrt_llm._utils.mpi_barrier() + + torch.manual_seed(42) + tensor_parallel_size = 2 + dtype = torch.float16 + m = mnk[0] + n = mnk[1] + k = mnk[2] + + # Ensure tensor dimensions are compatible with 2-way tensor parallelism + assert k % tensor_parallel_size == 0, ( + f"k dimension {k} must be divisible by tensor_parallel_size {tensor_parallel_size}" + ) + assert n % tensor_parallel_size == 0, ( + f"n dimension {n} must be divisible by tensor_parallel_size {tensor_parallel_size}" + ) + + a = torch.randn((m, k), dtype=dtype) + b = torch.randn((k, n), dtype=dtype) + c = torch.randn((m, n), dtype=dtype) + gamma = torch.randn((n), dtype=dtype) + + results = mpi_pool_executor.map( + run_single_rank_ar_rms_norm, + *zip(*[(tensor_parallel_size, a, b, c, gamma)] * tensor_parallel_size), + ) + for r in results: + assert r is True diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 9cd5ee37348..1c4161e2bbc 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -164,7 +164,7 @@ methods: default: False status: prototype allreduce_strategy: - annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC']] + annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC', 'NCCL_DEVICE']] default: AUTO status: beta decoding_config: