From b7c3dfda22e3ebe8ef8b15658bffcac4e46f946d Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:59:05 +0800 Subject: [PATCH 1/2] add wideep to multigpu pre-merge Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index 59f2d616335..010d3bb012f 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -33,6 +33,8 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] - condition: ranges: system_gpu_count: From cc175d848a14142643fa984dd823f35c2a008115 Mon Sep 17 00:00:00 2001 From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> Date: Tue, 26 Aug 2025 21:36:42 +0800 Subject: [PATCH 2/2] fix possible race conditions Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com> --- .../kernels/fusedMoeCommKernels.cu | 19 +++++++++++++++++-- .../moeLoadBalance/moeLoadBalanceKernels.cu | 4 ++-- 
cpp/tensorrt_llm/kernels/moePrepareKernels.cu | 4 ++++ tensorrt_llm/_mnnvl_utils.py | 1 + 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu b/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu index b04f4280a93..18c4a58471e 100644 --- a/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu +++ b/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu @@ -514,6 +514,21 @@ __device__ __forceinline__ void startWorkspaceS2G( cp_async_bulk_commit_group(); } +__device__ __forceinline__ void startWorkspaceS2GReg( + uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId) +{ + int copyInt4Count = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int4); + int4* sharedMemoryInt4 = reinterpret_cast<int4*>(sharedMemoryBase); + uint64_t* fifoPtr = fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t); + int4* fifoPtrInt4 = reinterpret_cast<int4*>(fifoPtr); +#pragma unroll 4 + for (int i = laneId; i < copyInt4Count; i += WARP_SIZE) + { + fifoPtrInt4[i] = sharedMemoryInt4[i]; + } + __syncwarp(); +} + __device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry, int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId) { @@ -761,10 +776,10 @@ public: FusedMoeProto::protoPack( mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); - tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2G(getFifoEntryPtr(), mShmemBase, + tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2GReg(getFifoEntryPtr(), mShmemBase, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); - tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); + // tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); nextToken(); } diff --git a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu index 6f67d45ed09..ace4f135b63 100644 --- a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu +++ b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu @@ -603,7 +603,7 @@ void moeWaitSignalForCpuStageHost(MoeLoadBalanceSingleLayerSignal* signal) bool ready = false; do { - auto loaded = signal->stepAndOwner; + auto loaded = __atomic_load_n(&signal->stepAndOwner, __ATOMIC_ACQUIRE); ready = getOwnerDevice(loaded) == MoeLoadBalanceSingleLayerSignal::kCPU; } while (!ready); std::atomic_thread_fence(std::memory_order_acquire); @@ -619,7 +619,7 @@ void moeSetSignalForGpuStageHost(MoeLoadBalanceSingleLayerSignal* signal, int64_ { value |= MoeLoadBalanceSingleLayerSignal::kSkipStep; } - signal->stepAndOwner = value; + __atomic_store_n(&signal->stepAndOwner, value, __ATOMIC_RELEASE); } } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/moePrepareKernels.cu b/cpp/tensorrt_llm/kernels/moePrepareKernels.cu index aea271dab58..3ad90c49c55 100644 --- a/cpp/tensorrt_llm/kernels/moePrepareKernels.cu +++ b/cpp/tensorrt_llm/kernels/moePrepareKernels.cu @@ -60,6 +60,10 @@ public: __forceinline__ __device__ void releaseValue(uint64_t value, int index) { // Avoid block on 0 + while (fifoConnInfo->values[index] != 0) + { + // loop wait until value reset + } fifoConnInfo->values[index] = value + 1; } diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py index d30b7316c39..1b8aef36142 100644 --- a/tensorrt_llm/_mnnvl_utils.py +++ b/tensorrt_llm/_mnnvl_utils.py @@ -369,6 +369,7 @@ def get_moe_workspaces(mapping: Mapping): torch.ops.trtllm.moe_initialize_workspace( MnnvlMoe.moe_workspace_tensor, mapping.tp_rank, mapping.tp_size ) + torch.cuda.synchronize() MnnvlMoe.moe_workspace.comm.barrier() return MnnvlMoe.moe_workspace_tensor