From e74eb6d58f288f13108a30804a96a28382e43e75 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 28 Aug 2025 17:29:32 -0700 Subject: [PATCH 1/6] Fix logic for communicator split Signed-off-by: Shiyu Li --- cpp/tensorrt_llm/pybind/runtime/bindings.cpp | 2 +- .../runtime/mcastDeviceMemory.cpp | 28 +++++++++---------- cpp/tensorrt_llm/runtime/mcastDeviceMemory.h | 6 +++- cpp/tensorrt_llm/runtime/mcastGPUBuffer.h | 6 ++-- tensorrt_llm/_torch/distributed/ops.py | 11 ++++++-- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 574249b6a23..7966e41b405 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -455,7 +455,7 @@ void initBindings(pybind11::module_& m) py::call_guard()); py::class_(m, "McastGPUBuffer") - .def(py::init(), py::call_guard()) + .def(py::init(), py::call_guard()) .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer, py::call_guard()) .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer, diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp index 950215e7542..841579ef012 100644 --- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp +++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp @@ -20,7 +20,7 @@ #include "tensorrt_llm/common/cudaDriverWrapper.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/runtime/utils/mpiUtils.h" + #include #include #include @@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran) } // namespace McastDeviceMemory::McastDeviceMemory( - size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink) + size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink) : mIsMNNvlink(mnNvlink) , mDeviceIdx(deviceIdx) , mGroupSize(groupSize) @@ -62,9 +62,14 @@ McastDeviceMemory::McastDeviceMemory( // From pytorch implementation for alignment constexpr size_t kSignalPadAlignment = 16UL; mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment); + // Initialize the MPI communicator for this group + mGroupComm = tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank); + int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()}; + TLLM_LOG_DEBUG( - "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu", - mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset); + "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, " + "device_idx: %d, Signal pad offset: %zu", + world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset); if (mIsMNNvlink) { @@ -127,9 +132,6 @@ McastDeviceMemory::~McastDeviceMemory() void McastDeviceMemory::allocMnMcastMem(size_t bufSize) { - - auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session(); - CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC; CUmemAllocationProp prop = {}; prop.requestedHandleTypes = handle_type; @@ -156,7 +158,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize) // All gather cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle)); memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle)); - mpi_comm.allgather( + mGroupComm.allgather( exphndl + mGroupRank * sizeof(CUmemFabricHandle), 
exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR); cudaDeviceSynchronize(); @@ -175,7 +177,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize) TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); } // Broadcast - mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0); + mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0); cudaDeviceSynchronize(); if (mGroupRank != 0) { @@ -210,12 +212,8 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize) void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize) { - // Create a std::set to include all ranks in range (0, group_size) - std::set ranks; - for (uint32_t i = 0; i < mGroupSize; ++i) - { - ranks.insert(i); - } + // Get the world ranks for ranks in this group + std::set ranks{tensorrt_llm::mpi::getWorldRanks(mGroupComm)}; // Reuse existing implementation mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks); mMcHandle = mNvlsHandle->mc_handle; diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h index 4afcc05223d..219a323a6a0 100644 --- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h +++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h @@ -17,6 +17,7 @@ #include "tensorrt_llm/common/mcastDevMemUtils.h" #include "tensorrt_llm/runtime/ipcNvlsMemory.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" #include #include #include @@ -42,7 +43,8 @@ class McastDeviceMemory McastDeviceMemory(McastDeviceMemory const&) = delete; McastDeviceMemory& operator=(McastDeviceMemory const&) = delete; - McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink); + McastDeviceMemory( + size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink); // We don't register the pointer in these two functions since we don't expect any python-level code would call // to obtain the raw pointers. @@ -94,6 +96,8 @@ class McastDeviceMemory size_t mSignalPadOffset; size_t mAllocationSize; + tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group + CUdeviceptr mMcPtr; CUmemGenericAllocationHandle mMcHandle; std::vector mUcHandles; diff --git a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h index 941ddb1a46a..9fd03de4519 100644 --- a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h +++ b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h @@ -34,10 +34,12 @@ class McastGPUBuffer //! \param bufSize The total size of the buffer in bytes. //! \param groupSize The number of ranks in the communication group. //! \param groupRank The rank of the local process within the group. + //! \param splitColor The color of the split for topology split. //! \param device The CUDA device for buffer allocation. //! \param mnNvlink Flag indicating if multi-node NVLink is used. 
- McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink) - : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink) + McastGPUBuffer( + size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, at::Device device, bool mnNvlink) + : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, device.index(), mnNvlink) , mBufSize(bufSize) , mLocalDevice(device) { diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index b3811204dfa..c643bc86f1f 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -8,7 +8,7 @@ import torch from torch import nn -from tensorrt_llm._utils import mpi_barrier +from tensorrt_llm._utils import mpi_comm from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AllReduceStrategy, MoEAllReduceParams) @@ -55,11 +55,14 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None: def get_allreduce_mnnvl_workspace( mapping: Mapping, dtype: torch.dtype ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]: + if not hasattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'): setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}', {}) - + # Support topology split + comm = mpi_comm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, + mapping.tp_rank) force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" allreduce_mnnvl_workspaces = getattr( @@ -77,6 +80,8 @@ def get_allreduce_mnnvl_workspace( buffer_size_in_bytes, mapping.tp_size, mapping.tp_rank, + # Split the communicator according to the topology + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, torch.device("cuda", mapping.local_rank), True, # mnNvlink ) @@ -87,7 +92,7 @@ def get_allreduce_mnnvl_workspace( buffer.fill_(-0.0) # CPU barrier since we assume this should not be called in cuda graph torch.cuda.synchronize() - mpi_barrier() + comm.Barrier() # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer From 7ae720cb84eaa713016e3534e26737fcdf4257f7 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 29 Aug 2025 11:48:10 -0700 Subject: [PATCH 2/6] Fix build error. 
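[Editor's note on the split semantics used above, not part of the patch.] PATCH 1 relies on MPI's split-by-color behavior in both McastDeviceMemory and ops.py: every rank that passes the same color ends up in the same sub-communicator, and key decides its rank inside that sub-communicator, which is why key=tp_rank makes the group rank line up with mapping.tp_rank. A minimal mpi4py sketch of just that behavior; the 2-way PP x 2-way TP layout and the rank arithmetic are assumptions for illustration (run with mpiexec -n 4):

```python
from mpi4py import MPI

world = MPI.COMM_WORLD
pp_size, tp_size = 2, 2                   # hypothetical 2-way PP x 2-way TP layout
pp_rank = world.Get_rank() // tp_size     # assumes tp peers are adjacent in the world comm
tp_rank = world.Get_rank() % tp_size

# Same color -> same sub-communicator; key orders ranks inside it.
tp_comm = world.Split(pp_rank, tp_rank)
assert tp_comm.Get_size() == tp_size
assert tp_comm.Get_rank() == tp_rank
```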
Signed-off-by: Shiyu Li --- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp | 6 +++--- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp index 841579ef012..9be590c7fce 100644 --- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp +++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp @@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory( , mAllocationSize(0) , mMcPtr(0) , mMcHandle(0) + , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank)) { TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx)); @@ -62,8 +63,6 @@ McastDeviceMemory::McastDeviceMemory( // From pytorch implementation for alignment constexpr size_t kSignalPadAlignment = 16UL; mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment); - // Initialize the MPI communicator for this group - mGroupComm = tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank); int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()}; TLLM_LOG_DEBUG( @@ -213,7 +212,8 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize) void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize) { // Get the world ranks for ranks in this group - std::set ranks{tensorrt_llm::mpi::getWorldRanks(mGroupComm)}; + auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm); + std::set ranks(ranks_.begin(), ranks_.end()); // Reuse existing implementation mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks); mMcHandle = mNvlsHandle->mc_handle; diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 0ca4d28085b..2f369d6eee8 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -771,8 +771,7 @@ def _compute_mlp_tp_size(self, intermediate_size: int, self.mapping.tp_size, ) - if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl( - ): + if tp > self.mapping.gpus_per_node: mlp_tp_size = math.gcd( tp, self.mapping.gpus_per_node, From 458f6d4c6a685fdaf668e9d8192c9d3da8c49669 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 29 Aug 2025 11:48:56 -0700 Subject: [PATCH 3/6] Fix nanobind error. 
Signed-off-by: Shiyu Li --- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index 5cdab7ba7e0..cb3852a18db 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -301,9 +301,8 @@ void initBindings(nb::module_& m) .def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"), nb::arg("sampling_config"), nb::arg("streaming"), nb::call_guard()) .def_prop_ro( - "decoder_stream", - [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); }, - nb::rv_policy::reference); + "decoder_stream", [](tr::GptDecoderBatched& self) -> tr::CudaStream const& + { return *self.getDecoderStream(); }, nb::rv_policy::reference); m.def( "lamport_initialize_all", @@ -314,8 +313,7 @@ void initBindings(nb::module_& m) }, "Lamport initialize all buffers", nb::call_guard()); m.def( - "lamport_initialize", - [](intptr_t buffer, size_t size) + "lamport_initialize", [](intptr_t buffer, size_t size) { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast(buffer), size, 0); }, "Lmaport initialize buffer", nb::call_guard()); m.def( @@ -361,7 +359,8 @@ void initBindings(nb::module_& m) nb::call_guard()); nb::class_(m, "McastGPUBuffer") - .def(nb::init(), nb::call_guard()) + .def(nb::init(), + nb::call_guard()) .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer, nb::call_guard()) .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer, From 9af99ec26ab6d5ab13c9ae91e6e5881de2381b23 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 29 Aug 2025 20:22:43 -0700 Subject: [PATCH 4/6] Fix binding errors. 
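[Editor's note, not part of the patch.] The reason PATCH 1 and 2 carry mGroupComm around is that the fabric-handle exchange in allocMnMcastMem must stay inside the split group instead of spanning the whole session; collectives on the session communicator would mix handles from unrelated groups, or hang if those groups never enter the call. Roughly, the pattern looks like the mpi4py sketch below; local_handle and mc_handle are placeholder byte strings standing in for exported CUmemFabricHandle data, and the trivial single-group split is purely illustrative:

```python
from mpi4py import MPI

world = MPI.COMM_WORLD
group_comm = world.Split(0, world.Get_rank())        # illustrative: one group holding every rank

local_handle = bytes(64)                             # placeholder for this rank's exported handle
all_handles = group_comm.allgather(local_handle)     # one entry per member of this group only

# The group root exports the multicast handle and broadcasts it within the group.
mc_handle = bytes(64) if group_comm.Get_rank() == 0 else None
mc_handle = group_comm.bcast(mc_handle, root=0)
```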
Signed-off-by: Shiyu Li --- .../nanobind/runtime/bindings.cpp | 5 +++-- cpp/tensorrt_llm/pybind/runtime/bindings.cpp | 12 +++++------ cpp/tensorrt_llm/runtime/mcastGPUBuffer.h | 20 ++++++++++++------- tensorrt_llm/_torch/distributed/ops.py | 16 +++++---------- .../_torch/models/modeling_deepseekv3.py | 2 +- 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index cb3852a18db..262226031dc 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -359,8 +359,9 @@ void initBindings(nb::module_& m) nb::call_guard()); nb::class_(m, "McastGPUBuffer") - .def(nb::init(), - nb::call_guard()) + .def(nb::init(), nb::arg("buf_size"), + nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"), + nb::arg("mn_nvlink"), nb::call_guard()) .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer, nb::call_guard()) .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer, diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 7966e41b405..b65a3824db9 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -395,9 +395,8 @@ void initBindings(pybind11::module_& m) .def("finalize", &tr::GptDecoderBatched::finalize, py::arg("decoder_state"), py::arg("batch_idx"), py::arg("sampling_config"), py::arg("streaming"), py::call_guard()) .def_property_readonly( - "decoder_stream", - [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); }, - py::return_value_policy::reference); + "decoder_stream", [](tr::GptDecoderBatched& self) -> tr::CudaStream const& + { return *self.getDecoderStream(); }, py::return_value_policy::reference); m.def( "lamport_initialize_all", @@ -408,8 +407,7 @@ void initBindings(pybind11::module_& m) }, "Lamport initialize all buffers", py::call_guard()); m.def( - "lamport_initialize", - [](intptr_t buffer, size_t size) + "lamport_initialize", [](intptr_t buffer, size_t size) { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast(buffer), size, 0); }, "Lmaport initialize buffer", py::call_guard()); m.def( @@ -455,7 +453,9 @@ void initBindings(pybind11::module_& m) py::call_guard()); py::class_(m, "McastGPUBuffer") - .def(py::init(), py::call_guard()) + .def(py::init(), py::arg("buf_size"), + py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"), + py::arg("mn_nvlink"), py::call_guard()) .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer, py::call_guard()) .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer, diff --git a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h index 9fd03de4519..4c011a790ba 100644 --- a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h +++ b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h @@ -38,10 +38,10 @@ class McastGPUBuffer //! \param device The CUDA device for buffer allocation. //! \param mnNvlink Flag indicating if multi-node NVLink is used. 
McastGPUBuffer( - size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, at::Device device, bool mnNvlink) - : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, device.index(), mnNvlink) + size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink) + : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink) , mBufSize(bufSize) - , mLocalDevice(device) + , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx)) { } @@ -51,7 +51,7 @@ class McastGPUBuffer //! \param dtype The data type of the tensor elements. //! \param storageOffset The offset in elements from the start of the buffer. //! \return An ATen tensor wrapping the unicast buffer section. - at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset) + at::Tensor getUCBuffer(uint32_t rank, std::vector sizes, torch::ScalarType dtype, int64_t storageOffset) { size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies()); size_t const elementSize = c10::elementSize(dtype); @@ -61,7 +61,10 @@ class McastGPUBuffer auto* dataPtr = static_cast(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize; auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice); - return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor(); + return at::for_blob(dataPtr, c10::IntArrayRef(sizes)) + .options(options) + .target_device(mLocalDevice) + .make_tensor(); } //! \brief Returns a PyTorch tensor view of the multicast buffer portion. @@ -69,7 +72,7 @@ class McastGPUBuffer //! \param dtype The data type of the tensor elements. //! \param storageOffset The offset in elements from the start of the buffer. //! \return An ATen tensor wrapping the multicast buffer section. 
- at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset) + at::Tensor getMCBuffer(std::vector sizes, torch::ScalarType dtype, int64_t storageOffset) { size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies()); size_t const elementSize = c10::elementSize(dtype); @@ -79,7 +82,10 @@ class McastGPUBuffer auto* dataPtr = static_cast(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize; auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice); - return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor(); + return at::for_blob(dataPtr, c10::IntArrayRef(sizes)) + .options(options) + .target_device(mLocalDevice) + .make_tensor(); } private: diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index c643bc86f1f..c5749681040 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -1,4 +1,3 @@ -import logging import math import os import platform @@ -17,7 +16,6 @@ from tensorrt_llm.plugin.plugin import CustomAllReduceHelper _thread_local = threading.local() -logger = logging.getLogger(__name__) def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor: @@ -61,8 +59,9 @@ def get_allreduce_mnnvl_workspace( setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}', {}) # Support topology split - comm = mpi_comm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, - mapping.tp_rank) + comm = mpi_comm().Split( + int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank), + mapping.tp_rank) force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" allreduce_mnnvl_workspaces = getattr( @@ -82,7 +81,7 @@ def get_allreduce_mnnvl_workspace( mapping.tp_rank, # Split the communicator according to the topology mapping.pp_rank * mapping.cp_size + mapping.cp_rank, - torch.device("cuda", mapping.local_rank), + mapping.local_rank, True, # mnNvlink ) @@ -463,12 +462,7 @@ def __init__(self, # Initialize MNNVL AllReduce if needed if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): - if self.mapping.tp_size != self.mapping.world_size: - logger.debug( - f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} " - f"!= world_size:{self.mapping.world_size}") - self.mnnvl_allreduce = None - elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype): + if MNNVLAllReduce.is_mnnvl(self.mapping, dtype): try: self.mnnvl_allreduce = MNNVLAllReduce( self.mapping, dtype) if dtype else None diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 2f369d6eee8..09b42c6fee4 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -775,7 +775,7 @@ def _compute_mlp_tp_size(self, intermediate_size: int, mlp_tp_size = math.gcd( tp, self.mapping.gpus_per_node, - ) # Avoid costly inter-node TP when MNNVL is not supported + ) # Avoid costly inter-node TP else: mlp_tp_size = tp return mlp_tp_size From 0042f0825267f7f76dca86a81b2fe1d65ecde341 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Mon, 1 Sep 2025 19:44:03 -0700 Subject: [PATCH 5/6] Fix build error. 
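[Editor's note, not part of the patch.] To make the color passed by get_allreduce_mnnvl_workspace concrete: color = pp_rank * cp_size + cp_rank, so ranks that differ only in tp_rank share a color and therefore share a workspace communicator, while every distinct (pp, cp) slice gets its own. A small self-contained enumeration; the 2x2x2 topology is hypothetical:

```python
pp_size, cp_size, tp_size = 2, 2, 2

for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        color = pp_rank * cp_size + cp_rank
        peers = [f"(pp={pp_rank}, cp={cp_rank}, tp={tp_rank})" for tp_rank in range(tp_size)]
        print(f"color {color}: " + ", ".join(peers))

# color 0: (pp=0, cp=0, tp=0), (pp=0, cp=0, tp=1)
# color 1: (pp=0, cp=1, tp=0), (pp=0, cp=1, tp=1)
# color 2: (pp=1, cp=0, tp=0), (pp=1, cp=0, tp=1)
# color 3: (pp=1, cp=1, tp=0), (pp=1, cp=1, tp=1)
```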
Signed-off-by: Shiyu Li --- cpp/tensorrt_llm/runtime/mcastDeviceMemory.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h index 219a323a6a0..d9428b4126c 100644 --- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h +++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h @@ -96,12 +96,12 @@ class McastDeviceMemory size_t mSignalPadOffset; size_t mAllocationSize; - tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group - CUdeviceptr mMcPtr; CUmemGenericAllocationHandle mMcHandle; std::vector mUcHandles; + tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group + // Host array of pointers std::vector mUcPtrs; std::vector mSignalPads; From b98a753b27f9f9ccd8f3d33aafcff87e8de55071 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Mon, 15 Sep 2025 09:47:45 -0700 Subject: [PATCH 6/6] Fix wrong format during rebasing. Signed-off-by: Shiyu Li --- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp | 8 +++++--- cpp/tensorrt_llm/pybind/runtime/bindings.cpp | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index 262226031dc..388819b957a 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -301,8 +301,9 @@ void initBindings(nb::module_& m) .def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"), nb::arg("sampling_config"), nb::arg("streaming"), nb::call_guard()) .def_prop_ro( - "decoder_stream", [](tr::GptDecoderBatched& self) -> tr::CudaStream const& - { return *self.getDecoderStream(); }, nb::rv_policy::reference); + "decoder_stream", + [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); }, + nb::rv_policy::reference); m.def( "lamport_initialize_all", @@ -313,7 +314,8 @@ void initBindings(nb::module_& m) }, "Lamport initialize all buffers", nb::call_guard()); m.def( - "lamport_initialize", [](intptr_t buffer, size_t size) + "lamport_initialize", + [](intptr_t buffer, size_t size) { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast(buffer), size, 0); }, "Lmaport initialize buffer", nb::call_guard()); m.def( diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index b65a3824db9..469aafe6476 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -395,8 +395,9 @@ void initBindings(pybind11::module_& m) .def("finalize", &tr::GptDecoderBatched::finalize, py::arg("decoder_state"), py::arg("batch_idx"), py::arg("sampling_config"), py::arg("streaming"), py::call_guard()) .def_property_readonly( - "decoder_stream", [](tr::GptDecoderBatched& self) -> tr::CudaStream const& - { return *self.getDecoderStream(); }, py::return_value_policy::reference); + "decoder_stream", + [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); }, + py::return_value_policy::reference); m.def( "lamport_initialize_all", @@ -407,7 +408,8 @@ void initBindings(pybind11::module_& m) }, "Lamport initialize all buffers", py::call_guard()); m.def( - "lamport_initialize", [](intptr_t buffer, size_t size) + "lamport_initialize", + [](intptr_t buffer, size_t size) { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast(buffer), size, 0); }, "Lmaport initialize 
buffer", py::call_guard()); m.def(