
Commit 5be2697

Address review comments.
1 parent 92cbd48 commit 5be2697

4 files changed (+36, -10 lines)


csrc/trtllm_mnnvl_allreduce.cu

9 additions, 2 deletions
@@ -53,11 +53,18 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt
       << "nranks must be between 2 and 64, got " << nranks;
   TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
       << "rank must be between 0 and nranks-1, got " << rank;
-  TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) ||
+  TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
+                  epsilon.has_value()) ||
                  !rmsnorm_fusion)
-      << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true";
+      << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
+         "true";

   if (rmsnorm_fusion) {
+    TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
+                   residual_in.value().size(1) == token_dim)
+        << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
+        << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
+        << ")";
     TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                    residual_out.value().size(1) == token_dim)
         << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)

flashinfer/comm/mnnvl.py

16 additions, 2 deletions
@@ -785,6 +785,9 @@ def __del__(self):
         if not hasattr(self, "is_multi_node"):
             return

+        if hasattr(self, "_ipc_socket"):
+            self._ipc_socket.close()
+
         # Skip cleanup during Python finalization to avoid segfaults
         # Especially cause the CUDA context could be destroyed at this point.
         if sys.is_finalizing():
@@ -951,7 +954,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)
         )

-        # Export local handle to fabric handle
+        # Export local handle to fabric handle or FD
        local_shareable_uc_handle = checkCudaErrors(
            cuda.cuMemExportToShareableHandle(
                self.uc_handles[self.group_rank],
@@ -990,6 +993,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
                    self._shareable_handle_type,
                )
            )
+            if (
+                self._shareable_handle_type
+                == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+            ):
+                # Close FD after import
+                os.close(all_shareable_uc_handles[p])

        # Initialize multicasting
        if self.group_rank == 0:
@@ -1038,7 +1047,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
                    self._shareable_handle_type,
                )
            )
-
+            if (
+                self._shareable_handle_type
+                == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+            ):
+                # Close FD after import
+                os.close(shareable_mc_handle)
         # Add device to multicast
         checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx))
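Both new blocks in _alloc_mn_mcast_mem follow the same rule: when the shareable handle type is a POSIX file descriptor, the FD is only needed until cuMemImportFromShareableHandle takes its own reference, so it is closed right after the import (and __del__ now closes the rank-exchange IPC socket as well). A minimal sketch of that lifecycle, assuming the cuda-python driver bindings used in mnnvl.py; the helper names below are illustrative, not flashinfer API:

```python
import os

from cuda import cuda  # cuda-python driver bindings, as used in flashinfer.comm.mnnvl


def _check(result):
    # Minimal stand-in for flashinfer's checkCudaErrors: cuda-python calls return
    # an (error, value...) tuple; raise on error and unwrap the value.
    err, *vals = result
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"CUDA driver error: {err}")
    return vals[0] if len(vals) == 1 else tuple(vals)


def import_peer_handle(shareable_handle, handle_type):
    """Import a peer's shareable allocation handle, then drop the transport FD."""
    handle = _check(
        cuda.cuMemImportFromShareableHandle(shareable_handle, handle_type)
    )
    if (
        handle_type
        == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    ):
        # The driver holds its own reference after import; keeping the raw FD
        # open would leak one descriptor per imported peer handle.
        os.close(shareable_handle)
    return handle
```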

include/flashinfer/utils.cuh

9 additions, 4 deletions
@@ -21,6 +21,7 @@
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>

+#include <atomic>
 #include <cstdint>
 #include <iostream>
 #include <type_traits>
@@ -289,16 +290,20 @@ inline std::pair<int, int> GetCudaComputeCapability() {
   return std::make_pair(major, minor);
 }

+// This function is thread-safe and caches the sm_count.
+// It only checks the current CUDA device, so it assumes each process handles a single GPU.
 inline int GetCudaMultiProcessorCount() {
-  static int sm_count = 0;
-  if (sm_count == 0) {
+  static std::atomic<int> sm_count{0};
+  int cached = sm_count.load(std::memory_order_relaxed);
+  if (cached == 0) {
     int device_id;
     cudaGetDevice(&device_id);
     cudaDeviceProp device_prop;
     cudaGetDeviceProperties(&device_prop, device_id);
-    sm_count = device_prop.multiProcessorCount;
+    cached = device_prop.multiProcessorCount;
+    sm_count.store(cached, std::memory_order_relaxed);
   }
-  return sm_count;
+  return cached;
 }

 template <typename T>
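With the relaxed atomic, the lazy cache becomes a benign race: two threads may both query the device, but they store the same value, so plain atomicity is enough and no stronger memory ordering is required. For illustration, the same per-process, single-GPU caching assumption can be sketched on the Python side (this is not flashinfer code, just an analogue using torch):

```python
# Illustrative only: a per-process SM-count cache with the same assumptions as
# GetCudaMultiProcessorCount (one GPU per process, value queried once).
import functools

import torch


@functools.lru_cache(maxsize=1)
def cached_sm_count() -> int:
    # Query the *current* CUDA device once and cache the result for the
    # lifetime of the process, mirroring the C++ helper's behavior.
    device = torch.cuda.current_device()
    return torch.cuda.get_device_properties(device).multi_processor_count
```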

tests/comm/test_trtllm_mnnvl_allreduce.py

2 additions, 2 deletions
@@ -1,6 +1,6 @@
 # Check torch version:
 import traceback
-from typing import Tuple
+from typing import Tuple, Optional

 import pytest
 import torch
@@ -277,7 +277,7 @@ def run_mnnvl_ar_full(
     fusion: bool,
     dtype: torch.dtype,
     hidden_size: int,
-    legacy_explicit_workspace_bytes: int = None,
+    legacy_explicit_workspace_bytes: Optional[int] = None,
     legacy_api: bool = False,
 ):
     """Core test logic for MNNVL AllReduce operations.
