
Commit 9af99ec

Fix binding errors.
Signed-off-by: Shiyu Li <[email protected]>
1 parent 458f6d4 commit 9af99ec

File tree

5 files changed: +28, -27 lines changed


cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 3 additions & 2 deletions

@@ -359,8 +359,9 @@ void initBindings(nb::module_& m)
         nb::call_guard<nb::gil_scoped_release>());

     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, at::Device, bool>(),
-            nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 6 additions & 6 deletions

@@ -395,9 +395,8 @@ void initBindings(pybind11::module_& m)
         .def("finalize", &tr::GptDecoderBatched::finalize, py::arg("decoder_state"), py::arg("batch_idx"),
             py::arg("sampling_config"), py::arg("streaming"), py::call_guard<py::gil_scoped_release>())
         .def_property_readonly(
-            "decoder_stream",
-            [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); },
-            py::return_value_policy::reference);
+            "decoder_stream", [](tr::GptDecoderBatched& self) -> tr::CudaStream const&
+            { return *self.getDecoderStream(); }, py::return_value_policy::reference);

     m.def(
         "lamport_initialize_all",
@@ -408,8 +407,7 @@ void initBindings(pybind11::module_& m)
         },
         "Lamport initialize all buffers", py::call_guard<py::gil_scoped_release>());
     m.def(
-        "lamport_initialize",
-        [](intptr_t buffer, size_t size)
+        "lamport_initialize", [](intptr_t buffer, size_t size)
         { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast<void*>(buffer), size, 0); },
         "Lmaport initialize buffer", py::call_guard<py::gil_scoped_release>());
     m.def(
@@ -455,7 +453,9 @@ void initBindings(pybind11::module_& m)
         py::call_guard<py::gil_scoped_release>());

     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, at::Device, bool>(), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             py::call_guard<py::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,

cpp/tensorrt_llm/runtime/mcastGPUBuffer.h

Lines changed: 13 additions & 7 deletions

@@ -38,10 +38,10 @@ class McastGPUBuffer
     //! \param device The CUDA device for buffer allocation.
    //! \param mnNvlink Flag indicating if multi-node NVLink is used.
    McastGPUBuffer(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, device.index(), mnNvlink)
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }

@@ -51,7 +51,7 @@ class McastGPUBuffer
     //! \param dtype The data type of the tensor elements.
    //! \param storageOffset The offset in elements from the start of the buffer.
    //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -61,15 +61,18 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;

         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
    //! \param sizes The desired shape (dimensions) of the tensor.
    //! \param dtype The data type of the tensor elements.
    //! \param storageOffset The offset in elements from the start of the buffer.
    //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -79,7 +82,10 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;

         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

 private:
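
The sizes parameter switches from c10::IntArrayRef, a non-owning view that the binding layers likely cannot convert by default, to an owning std::vector that maps directly from a Python list; internally it is wrapped back into an IntArrayRef for at::for_blob. A hedged sketch of the resulting Python calls, reusing the hypothetical buf object from the earlier sketch (shapes, dtype, and offsets are illustrative):

import torch

# View rank 0's unicast slice as a 1024x1024 fp16 tensor starting at element offset 0.
uc = buf.get_uc_buffer(0, [1024, 1024], torch.float16, 0)

# View the multicast region with the same shape, dtype, and offset.
mc = buf.get_mc_buffer([1024, 1024], torch.float16, 0)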

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 5 additions & 11 deletions

@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -17,7 +16,6 @@
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 _thread_local = threading.local()
-logger = logging.getLogger(__name__)


 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -61,8 +59,9 @@ def get_allreduce_mnnvl_workspace(
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})
     # Support topology split
-    comm = mpi_comm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
-                            mapping.tp_rank)
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

     allreduce_mnnvl_workspaces = getattr(
@@ -82,7 +81,7 @@ def get_allreduce_mnnvl_workspace(
             mapping.tp_rank,
             # Split the communicator according to the topology
             mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
-            torch.device("cuda", mapping.local_rank),
+            mapping.local_rank,
             True,  # mnNvlink
         )

@@ -463,12 +462,7 @@ def __init__(self,
         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
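
The Split color is now wrapped in int(), presumably so a built-in Python integer reaches mpi4py; each (pp_rank, cp_rank) pair gets its own color, so only ranks that differ in tp_rank share a communicator. A small worked example of the color/key computation (rank counts are illustrative, not from the commit):

# Illustrative only: shows how the Split color groups ranks into TP communicators.
pp_size, cp_size, tp_size = 2, 2, 2
for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        color = int(pp_rank * cp_size + cp_rank)  # same color => same communicator
        for tp_rank in range(tp_size):
            # key=tp_rank orders ranks within each TP communicator
            print(f"pp={pp_rank} cp={cp_rank} tp={tp_rank} -> color={color}, key={tp_rank}")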

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion

@@ -775,7 +775,7 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size
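
The gcd cap keeps MLP tensor parallelism within a single node. A quick worked example with illustrative numbers:

import math

tp, gpus_per_node = 16, 8
mlp_tp_size = math.gcd(tp, gpus_per_node)  # -> 8: the MLP TP group stays intra-node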
