@@ -19,6 +19,7 @@
 #include <cub/cub.cuh>
 
 #include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
 
 namespace cg = cooperative_groups;
@@ -28,6 +29,8 @@ namespace tensorrt_llm
 namespace kernels
 {
 
+using tensorrt_llm::common::launchWithPdlWhenEnabled;
+
 int getOwnerDevice(unsigned long long int stepAndOwner)
 {
     return static_cast<int>(stepAndOwner & MoeLoadBalanceSingleLayerSignal::kDevice);
@@ -71,6 +74,11 @@ __device__ __forceinline__ void moeWaitSignalForGpuStageFunc(MoeLoadBalanceSingl
 
 __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
 {
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     if (threadIdx.x == 0 and blockIdx.x == 0)
     {
         moeWaitSignalForGpuStageFunc(signal, enabled);
@@ -79,6 +87,11 @@ __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal*
 
 __global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal)
 {
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     if (threadIdx.x == 0 and blockIdx.x == 0)
     {
         unsigned long long int loaded = signal->stepAndOwner;
@@ -91,7 +104,8 @@ __global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* s
 
 void moeWaitSignalForGpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, int* enabled, cudaStream_t stream)
 {
-    moeWaitSignalForGpuStageKernel<<<1, 1, 0, stream>>>(signal, enabled);
+    launchWithPdlWhenEnabled(
+        "moeWaitSignalForGpuStage", moeWaitSignalForGpuStageKernel, 1, 1, 0, stream, signal, enabled);
 }
 
 void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
@@ -119,7 +133,7 @@ void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, in
 
 void moeSetSignalForCpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, cudaStream_t stream)
 {
-    moeSetSignalForCpuStageKernel<<<1, 1, 0, stream>>>(signal);
+    launchWithPdlWhenEnabled("moeSetSignalForCpuStage", moeSetSignalForCpuStageKernel, 1, 1, 0, stream, signal);
 }
 
 void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
@@ -138,6 +152,10 @@ __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int*
     TYPE oldExpertTokenCount = {0};
     int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
     TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
 }
 
@@ -149,6 +167,10 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
         return;
     }
     TYPE oldExpertTokenCount = {0};
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     if (blockIdx.x > 0)
     {
         int* oldExpertTokenCountPtr = expertTokenCount + metaInfo.expertCount * (blockIdx.x - 1);
@@ -177,6 +199,10 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertToke
         sharedExpertCount[i] = 0;
     }
     __syncthreads();
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < totalEltCount; idx += gridDim.x * blockDim.x)
     {
         int expertId = gatheredRawExpertIds[idx];
@@ -200,6 +226,10 @@ __global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadB
         return;
     }
     int expertIdx = blockIdx.x * blockDim.x + threadIdx.x;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
     int expertTokenCount = expertTokenCountPtr[expertIdx];
     float* loadFactor = statisticInfo.expertLoadFactor;
     loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
@@ -232,6 +262,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
             = {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&expertTokenCount)};
         TLLM_CHECK_WITH_INFO(
             threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
+        // TODO: add PDL support with cooperative launch
         TLLM_CUDA_CHECK(cudaLaunchCooperativeKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
     }
 
@@ -245,7 +276,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
             blockCount = smCount;
         }
         int sharedMemorySize = metaInfo.expertCount * sizeof(int);
-        statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
+        launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
             metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds);
     }
 
@@ -254,7 +285,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
         // only the last stage needs to update the load factor.
         int threadCount = 128;
         int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
-        updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
+        launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream,
             metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled);
     }
 }
@@ -282,11 +313,10 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
         }
         dim3 gridDim(1);
         dim3 blockDim(threadCount);
-        void* args[]
-            = {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
         TLLM_CHECK_WITH_INFO(
             threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
-        TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
+        launchWithPdlWhenEnabled(
+            "zeroExpertTokenCount", kernelFunc, gridDim, blockDim, 0, stream, metaInfo, enabled, localExpertTokenCount);
     }
 
     {
@@ -299,7 +329,7 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
             blockCount = smCount;
         }
         int sharedMemorySize = metaInfo.expertCount * sizeof(int);
-        statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
+        launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
             metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
     }
 }
@@ -309,8 +339,8 @@ void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBala
 {
     int threadCount = 128;
     int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
-    updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
-        metaInfo, statisticInfo, globalExpertTokenCount, enabled);
+    launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream, metaInfo,
+        statisticInfo, globalExpertTokenCount, enabled);
 }
 
 template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
@@ -320,13 +350,18 @@ __global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo
     extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
     int expertIds[ITEM_PER_THREAD];
     int slotIds[ITEM_PER_THREAD];
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+
     for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
     {
         sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
     }
 
     int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;
-
     for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
     {
         int tokenIdxBase = blockOffset + threadIdx.x;
@@ -379,6 +414,12 @@ __global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacem
379414
380415 __shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
381416 __shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
417+
418+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
419+ cudaGridDependencySynchronize ();
420+ cudaTriggerProgrammaticLaunchCompletion ();
421+ #endif
422+
382423 for (int expertIdx = threadIdx .x ; expertIdx < metaInfo.expertCount ; expertIdx += THREAD_COUNT)
383424 {
384425 int replicaCount = placementInfo.expertReplicaCount [expertIdx];
@@ -484,6 +525,11 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
     __shared__ int sharedSortedExpertId[THREAD_COUNT * ITEM_PER_THREAD];
     __shared__ int sharedExpertStartThread[MAX_EXPERT_COUNT];
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+
     for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
     {
         sharedExpertTokenCount[expertIdx] = 0;
@@ -500,7 +546,6 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
     __syncthreads();
 
     int expertIds[ITEM_PER_THREAD];
-
     for (int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD; blockOffset < tokenCount * metaInfo.topK;
          blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
     {
@@ -586,14 +631,15 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
     int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
     if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
     {
+        auto* kernelFn = moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>;
         // no redundant expert, so we don't need complex routing; just assign to the correct slot.
-        moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
-            <<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
-                metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
+        launchWithPdlWhenEnabled("moeComputeRouteNoRedundant", kernelFn, blockCount, kThreadCount, dynamicShmSize,
+            stream, metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
     }
     else
     {
-        moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
+        auto* kernelFn = moeComputeRouteKernel<1024, kThreadCount, kEltPerThread>;
+        launchWithPdlWhenEnabled("moeComputeRoute", kernelFn, blockCount, kThreadCount, dynamicShmSize, stream,
             metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
     }
 }