Skip to content

Commit d52b039

Browse files
authored
Merge branch 'main' into fix-rpc
Signed-off-by: Yan Chunwei <[email protected]>
2 parents b492002 + fac47e2 commit d52b039

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1970
-291
lines changed

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,9 @@ class CacheSender::Impl
291291
mSelfState.setCommState(std::move(commState));
292292
}
293293

294-
[[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId) const
294+
[[nodiscard]] size_t getCounterpartsCount(LlmRequest::RequestIdType requestId)
295295
{
296+
std::unique_lock<std::mutex> lock(mMtxForMap);
296297
auto it = mRequestToSession.find(requestId);
297298
TLLM_CHECK(it != mRequestToSession.end());
298299
return it->second.getConnections().size();

cpp/tensorrt_llm/common/envUtils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
*/
1717

1818
#pragma once
19+
#include "tensorrt_llm/common/cudaUtils.h"
1920
#include <cstdint>
21+
#include <cuda_runtime.h>
2022
#include <optional>
2123
#include <string>
2224

@@ -55,6 +57,26 @@ int getEnvMmhaKernelBlockSize();
5557
// Whether PDL is enabled.
5658
bool getEnvEnablePDL();
5759

60+
template <typename KernelFn, typename... Args>
61+
inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 grid, dim3 block, size_t dynamicShmSize,
62+
cudaStream_t stream, Args&&... args)
63+
{
64+
TLLM_LOG_DEBUG("Enable PDL in %s", name);
65+
cudaLaunchConfig_t kernelConfig;
66+
kernelConfig.gridDim = grid;
67+
kernelConfig.blockDim = block;
68+
kernelConfig.dynamicSmemBytes = dynamicShmSize;
69+
kernelConfig.stream = stream;
70+
71+
cudaLaunchAttribute attrs[1];
72+
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
73+
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
74+
kernelConfig.attrs = attrs;
75+
kernelConfig.numAttrs = 1;
76+
77+
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...));
78+
}
79+
5880
bool getEnvUseUCXKvCache();
5981

6082
bool getEnvUseMPIKvCache();

cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace tensorrt_llm
2727
namespace kernels
2828
{
2929

30+
using tensorrt_llm::common::launchWithPdlWhenEnabled;
31+
3032
// Quantize a contiguous shared-memory buffer containing elements of DType into NVFP4 with per-16-element FP8 scales.
3133
// Output layout (repeated per 16-element group per lane), followed by one global scale float:
3234
// [WARP_SIZE * 8 bytes packed e2m1 values] [WARP_SIZE * 1 byte E4M3 per-group scales] ... [global_scale (4 bytes)]
@@ -1069,6 +1071,10 @@ public:
10691071

10701072
int sendIndex = mPairInfo.channel;
10711073
uint32_t phaseParity = 0;
1074+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1075+
cudaGridDependencySynchronize();
1076+
cudaTriggerProgrammaticLaunchCompletion();
1077+
#endif
10721078
for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount)
10731079
{
10741080
int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex];
@@ -1140,6 +1146,10 @@ public:
11401146
int recvIndex = mPairInfo.channel;
11411147
uint32_t phaseParity = 0;
11421148
bool needRelease = false;
1149+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1150+
cudaGridDependencySynchronize();
1151+
cudaTriggerProgrammaticLaunchCompletion();
1152+
#endif
11431153
for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount)
11441154
{
11451155
int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex];
@@ -1459,7 +1469,8 @@ void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cu
14591469

14601470
dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta);
14611471
dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta);
1462-
kernelFn<<<grid, block, totalDynamicShmSize, stream>>>(params, workspace, hasBasicFields);
1472+
launchWithPdlWhenEnabled(
1473+
"moeAllToAll", kernelFn, grid, block, totalDynamicShmSize, stream, params, workspace, hasBasicFields);
14631474
TLLM_CUDA_CHECK(cudaGetLastError());
14641475
}
14651476

cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <cuda_runtime_api.h>
2121

2222
#include "tensorrt_llm/common/cudaUtils.h"
23+
#include "tensorrt_llm/common/envUtils.h"
2324
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
2425

2526
namespace tensorrt_llm

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <cub/cub.cuh>
2020

2121
#include "tensorrt_llm/common/cudaUtils.h"
22+
#include "tensorrt_llm/common/envUtils.h"
2223
#include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h"
2324

2425
namespace cg = cooperative_groups;
@@ -28,6 +29,8 @@ namespace tensorrt_llm
2829
namespace kernels
2930
{
3031

32+
using tensorrt_llm::common::launchWithPdlWhenEnabled;
33+
3134
int getOwnerDevice(unsigned long long int stepAndOwner)
3235
{
3336
return static_cast<int>(stepAndOwner & MoeLoadBalanceSingleLayerSignal::kDevice);
@@ -71,6 +74,11 @@ __device__ __forceinline__ void moeWaitSignalForGpuStageFunc(MoeLoadBalanceSingl
7174

7275
__global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
7376
{
77+
78+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
79+
cudaGridDependencySynchronize();
80+
cudaTriggerProgrammaticLaunchCompletion();
81+
#endif
7482
if (threadIdx.x == 0 and blockIdx.x == 0)
7583
{
7684
moeWaitSignalForGpuStageFunc(signal, enabled);
@@ -79,6 +87,11 @@ __global__ void moeWaitSignalForGpuStageKernel(MoeLoadBalanceSingleLayerSignal*
7987

8088
__global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* signal)
8189
{
90+
91+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
92+
cudaGridDependencySynchronize();
93+
cudaTriggerProgrammaticLaunchCompletion();
94+
#endif
8295
if (threadIdx.x == 0 and blockIdx.x == 0)
8396
{
8497
unsigned long long int loaded = signal->stepAndOwner;
@@ -91,7 +104,8 @@ __global__ void moeSetSignalForCpuStageKernel(MoeLoadBalanceSingleLayerSignal* s
91104

92105
void moeWaitSignalForGpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, int* enabled, cudaStream_t stream)
93106
{
94-
moeWaitSignalForGpuStageKernel<<<1, 1, 0, stream>>>(signal, enabled);
107+
launchWithPdlWhenEnabled(
108+
"moeWaitSignalForGpuStage", moeWaitSignalForGpuStageKernel, 1, 1, 0, stream, signal, enabled);
95109
}
96110

97111
void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, int* enabled)
@@ -119,7 +133,7 @@ void moeWaitSignalForGpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal, in
119133

120134
void moeSetSignalForCpuStageDevice(MoeLoadBalanceSingleLayerSignal* signal, cudaStream_t stream)
121135
{
122-
moeSetSignalForCpuStageKernel<<<1, 1, 0, stream>>>(signal);
136+
launchWithPdlWhenEnabled("moeSetSignalForCpuStage", moeSetSignalForCpuStageKernel, 1, 1, 0, stream, signal);
123137
}
124138

125139
void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
@@ -138,6 +152,10 @@ __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int*
138152
TYPE oldExpertTokenCount = {0};
139153
int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
140154
TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);
155+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
156+
cudaGridDependencySynchronize();
157+
cudaTriggerProgrammaticLaunchCompletion();
158+
#endif
141159
typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount;
142160
}
143161

@@ -149,6 +167,10 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en
149167
return;
150168
}
151169
TYPE oldExpertTokenCount = {0};
170+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
171+
cudaGridDependencySynchronize();
172+
cudaTriggerProgrammaticLaunchCompletion();
173+
#endif
152174
if (blockIdx.x > 0)
153175
{
154176
int* oldExpertTokenCountPtr = expertTokenCount + metaInfo.expertCount * (blockIdx.x - 1);
@@ -177,6 +199,10 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertToke
177199
sharedExpertCount[i] = 0;
178200
}
179201
__syncthreads();
202+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
203+
cudaGridDependencySynchronize();
204+
cudaTriggerProgrammaticLaunchCompletion();
205+
#endif
180206
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < totalEltCount; idx += gridDim.x * blockDim.x)
181207
{
182208
int expertId = gatheredRawExpertIds[idx];
@@ -200,6 +226,10 @@ __global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadB
200226
return;
201227
}
202228
int expertIdx = blockIdx.x * blockDim.x + threadIdx.x;
229+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
230+
cudaGridDependencySynchronize();
231+
cudaTriggerProgrammaticLaunchCompletion();
232+
#endif
203233
int expertTokenCount = expertTokenCountPtr[expertIdx];
204234
float* loadFactor = statisticInfo.expertLoadFactor;
205235
loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount;
@@ -232,6 +262,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
232262
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&expertTokenCount)};
233263
TLLM_CHECK_WITH_INFO(
234264
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
265+
// TODO: add PDL support with cooperative launch
235266
TLLM_CUDA_CHECK(cudaLaunchCooperativeKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
236267
}
237268

@@ -245,7 +276,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
245276
blockCount = smCount;
246277
}
247278
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
248-
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
279+
launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
249280
metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds);
250281
}
251282

@@ -254,7 +285,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
254285
// only last stage need update load factor.
255286
int threadCount = 128;
256287
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
257-
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
288+
launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream,
258289
metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled);
259290
}
260291
}
@@ -282,11 +313,10 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
282313
}
283314
dim3 gridDim(1);
284315
dim3 blockDim(threadCount);
285-
void* args[]
286-
= {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)};
287316
TLLM_CHECK_WITH_INFO(
288317
threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount);
289-
TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream));
318+
launchWithPdlWhenEnabled(
319+
"zeroExpertTokenCount", kernelFunc, gridDim, blockDim, 0, stream, metaInfo, enabled, localExpertTokenCount);
290320
}
291321

292322
{
@@ -299,7 +329,7 @@ void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int nu
299329
blockCount = smCount;
300330
}
301331
int sharedMemorySize = metaInfo.expertCount * sizeof(int);
302-
statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>(
332+
launchWithPdlWhenEnabled("statisticKernel", statisticKernel, blockCount, threadCount, sharedMemorySize, stream,
303333
metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds);
304334
}
305335
}
@@ -309,8 +339,8 @@ void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBala
309339
{
310340
int threadCount = 128;
311341
int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount;
312-
updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(
313-
metaInfo, statisticInfo, globalExpertTokenCount, enabled);
342+
launchWithPdlWhenEnabled("updateLoadFactor", updateLoadFactorKernel, blockCount, threadCount, 0, stream, metaInfo,
343+
statisticInfo, globalExpertTokenCount, enabled);
314344
}
315345

316346
template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
@@ -320,13 +350,18 @@ __global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo
320350
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
321351
int expertIds[ITEM_PER_THREAD];
322352
int slotIds[ITEM_PER_THREAD];
353+
354+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
355+
cudaGridDependencySynchronize();
356+
cudaTriggerProgrammaticLaunchCompletion();
357+
#endif
358+
323359
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
324360
{
325361
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
326362
}
327363

328364
int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;
329-
330365
for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
331366
{
332367
int tokenIdxBase = blockOffset + threadIdx.x;
@@ -379,6 +414,12 @@ __global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacem
379414

380415
__shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
381416
__shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
417+
418+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
419+
cudaGridDependencySynchronize();
420+
cudaTriggerProgrammaticLaunchCompletion();
421+
#endif
422+
382423
for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
383424
{
384425
int replicaCount = placementInfo.expertReplicaCount[expertIdx];
@@ -484,6 +525,11 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
484525
__shared__ int sharedSortedExpertId[THREAD_COUNT * ITEM_PER_THREAD];
485526
__shared__ int sharedExpertStartThread[MAX_EXPERT_COUNT];
486527

528+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
529+
cudaGridDependencySynchronize();
530+
cudaTriggerProgrammaticLaunchCompletion();
531+
#endif
532+
487533
for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
488534
{
489535
sharedExpertTokenCount[expertIdx] = 0;
@@ -500,7 +546,6 @@ __global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePl
500546
__syncthreads();
501547

502548
int expertIds[ITEM_PER_THREAD];
503-
504549
for (int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD; blockOffset < tokenCount * metaInfo.topK;
505550
blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
506551
{
@@ -586,14 +631,15 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
586631
int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
587632
if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
588633
{
634+
auto* kernelFn = moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>;
589635
// no redundant expert, so we don't need complex routing, but just assign to the correct solt.
590-
moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
591-
<<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
592-
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
636+
launchWithPdlWhenEnabled("moeComputeRouteNoRedundant", kernelFn, blockCount, kThreadCount, dynamicShmSize,
637+
stream, metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
593638
}
594639
else
595640
{
596-
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
641+
auto* kernelFn = moeComputeRouteKernel<1024, kThreadCount, kEltPerThread>;
642+
launchWithPdlWhenEnabled("moeComputeRoute", kernelFn, blockCount, kThreadCount, dynamicShmSize, stream,
597643
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
598644
}
599645
}

0 commit comments

Comments
 (0)