Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,21 @@ __device__ __forceinline__ void startWorkspaceS2G(
cp_async_bulk_commit_group();
}

__device__ __forceinline__ void startWorkspaceS2GReg(
uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId)
{
int copyInt4Count = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int4);
int4* sharedMemoryInt4 = reinterpret_cast<int4*>(sharedMemoryBase);
uint64_t* fifoPtr = fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t);
int4* fifoPtrInt4 = reinterpret_cast<int4*>(fifoPtr);
#pragma unroll 4
for (int i = laneId; i < copyInt4Count; i += WARP_SIZE)
{
fifoPtrInt4[i] = sharedMemoryInt4[i];
}
__syncwarp();
}

__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry,
int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId)
{
Expand Down Expand Up @@ -761,10 +776,10 @@ public:
FusedMoeProto::protoPack(
mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);

tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2G(getFifoEntryPtr(), mShmemBase,
tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2GReg(getFifoEntryPtr(), mShmemBase,
mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId);

tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();
// tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead();

nextToken();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ void moeWaitSignalForCpuStageHost(MoeLoadBalanceSingleLayerSignal* signal)
bool ready = false;
do
{
auto loaded = signal->stepAndOwner;
auto loaded = __atomic_load_n(&signal->stepAndOwner, __ATOMIC_ACQUIRE);
ready = getOwnerDevice(loaded) == MoeLoadBalanceSingleLayerSignal::kCPU;
} while (!ready);
std::atomic_thread_fence(std::memory_order_acquire);
Expand All @@ -619,7 +619,7 @@ void moeSetSignalForGpuStageHost(MoeLoadBalanceSingleLayerSignal* signal, int64_
{
value |= MoeLoadBalanceSingleLayerSignal::kSkipStep;
}
signal->stepAndOwner = value;
__atomic_store_n(&signal->stepAndOwner, value, __ATOMIC_RELEASE);
}

} // namespace kernels
Expand Down
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/kernels/moePrepareKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public:
__forceinline__ __device__ void releaseValue(uint64_t value, int index)
{
// Avoid block on 0
while (fifoConnInfo->values[index] != 0)
{
// loop wait until value reset
}
fifoConnInfo->values[index] = value + 1;
}

Expand Down
1 change: 1 addition & 0 deletions tensorrt_llm/_mnnvl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def get_moe_workspaces(mapping: Mapping):
torch.ops.trtllm.moe_initialize_workspace(
MnnvlMoe.moe_workspace_tensor, mapping.tp_rank, mapping.tp_size
)
torch.cuda.synchronize()
MnnvlMoe.moe_workspace.comm.barrier()
return MnnvlMoe.moe_workspace_tensor

Expand Down
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
- condition:
ranges:
system_gpu_count:
Expand Down