Skip to content

Commit a7b677d

Browse files
committed
rework how nvlink is checked
Signed-off-by: Ludwig Schneider <[email protected]>
1 parent 2827024 commit a7b677d

File tree

2 files changed: +20 -14 lines changed

cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -712,15 +712,19 @@ void AllreducePlugin::setGroupTopology() noexcept
712712
if (isMultiNode)
713713
{
714714
TLLM_LOG_INFO("Found inter-node TP group for rank %d", rank);
715-
// For MNNVL (Multi-Node NVLink), we need to check P2P/NVLINK even in multi-node
716-
// For other strategies, multi-node means no P2P/NVLINK support
717-
if (mStrategy != AllReduceStrategyType::MNNVL)
715+
// Strategies that don't support multi-node P2P/NVLINK
716+
// MIN_LATENCY, ONESHOT, TWOSHOT, LOWPRECISION are designed for intra-node only
717+
if (mStrategy == AllReduceStrategyType::MIN_LATENCY || mStrategy == AllReduceStrategyType::ONESHOT
718+
|| mStrategy == AllReduceStrategyType::TWOSHOT || mStrategy == AllReduceStrategyType::LOWPRECISION)
718719
{
719720
mIsP2PSupported = false;
720721
mIsNVLINKSupported = false;
722+
TLLM_LOG_INFO("Strategy %s does not support multi-node, setting P2P/NVLINK to false for rank %d",
723+
tensorrt_llm::kernels::toString(mStrategy).c_str(), rank);
721724
return;
722725
}
723-
TLLM_LOG_INFO("MNNVL strategy detected - checking multi-node P2P/NVLINK for rank %d", rank);
726+
// Other strategies (like MNNVL) will continue to check P2P/NVLINK in multi-node
727+
TLLM_LOG_INFO("Multi-node strategy detected - checking P2P/NVLINK across nodes for rank %d", rank);
724728
}
725729
else
726730
{
@@ -732,10 +736,9 @@ void AllreducePlugin::setGroupTopology() noexcept
732736
mIsP2PSupported = true;
733737
mIsNVLINKSupported = true;
734738

735-
// For MNNVL in multi-node, check all devices in the full group
739+
// For multi-node with supported strategies, check all devices in the full group
736740
// For intra-node, check only localGroup
737-
std::set<int> const& devicesToCheck
738-
= (isMultiNode && mStrategy == AllReduceStrategyType::MNNVL) ? mGroup : localGroup;
741+
std::set<int> const& devicesToCheck = isMultiNode ? mGroup : localGroup;
739742

740743
// Use cudaDeviceCanAccessPeer to determine whether p2p is supported,
741744
// and use nvml to determine whether there are nvlink links between ranks.

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1009,15 +1009,19 @@ class AllreduceOp
10091009
if (isMultiNode)
10101010
{
10111011
TLLM_LOG_INFO("Found inter-node TP group for rank %d", rank);
1012-
// For MNNVL (Multi-Node NVLink), we need to check P2P/NVLINK even in multi-node
1013-
// For other strategies, multi-node means no P2P/NVLINK support
1014-
if (mStrategy != AllReduceStrategyType::MNNVL)
1012+
// Strategies that don't support multi-node P2P/NVLINK
1013+
// MIN_LATENCY, ONESHOT, TWOSHOT, LOWPRECISION are designed for intra-node only
1014+
if (mStrategy == AllReduceStrategyType::MIN_LATENCY || mStrategy == AllReduceStrategyType::ONESHOT
1015+
|| mStrategy == AllReduceStrategyType::TWOSHOT || mStrategy == AllReduceStrategyType::LOWPRECISION)
10151016
{
10161017
mIsP2PSupported = false;
10171018
mIsNVLINKSupported = false;
1019+
TLLM_LOG_INFO("Strategy %s does not support multi-node, setting P2P/NVLINK to false for rank %d",
1020+
tensorrt_llm::kernels::toString(mStrategy).c_str(), rank);
10181021
return;
10191022
}
1020-
TLLM_LOG_INFO("MNNVL strategy detected - checking multi-node P2P/NVLINK for rank %d", rank);
1023+
// Other strategies (like MNNVL) will continue to check P2P/NVLINK in multi-node
1024+
TLLM_LOG_INFO("Multi-node strategy detected - checking P2P/NVLINK across nodes for rank %d", rank);
10211025
}
10221026
else
10231027
{
@@ -1028,10 +1032,9 @@ class AllreduceOp
10281032
mIsP2PSupported = true;
10291033
mIsNVLINKSupported = true;
10301034

1031-
// For MNNVL in multi-node, check all devices in the full group
1035+
// For multi-node with supported strategies, check all devices in the full group
10321036
// For intra-node, check only local_group
1033-
std::set<int> const& devices_to_check
1034-
= (isMultiNode && mStrategy == AllReduceStrategyType::MNNVL) ? mGroup : local_group;
1037+
std::set<int> const& devices_to_check = isMultiNode ? mGroup : local_group;
10351038

10361039
// TODO(ytong): Should we provide group topology info instead of querying it here?
10371040
// Use cudaDeviceCanAccessPeer to determine whether p2p is supported,

0 commit comments

Comments
 (0)