
Commit ffd8259

Merge branch 'NVIDIA:main' into mtp_optimizations_round1
2 parents f9fc02b + 3dfc819 commit ffd8259

File tree

19 files changed: +179 -113 lines changed


.github/CODEOWNERS

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@
 /tensorrt_llm/commands/bench.py @NVIDIA/trtllm-bench-reviewers
 docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 
+## TensorRT-LLM LLM API
+/tensorrt_llm/llmapi @NVIDIA/trt-llm-llmapi-devs
+/tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs
 
 # The rule below requires that any PR modifying public APIs must be approved by at least one member
 # of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team.

cpp/tensorrt_llm/kernels/moeCommKernels.cu

Lines changed: 10 additions & 7 deletions
@@ -728,21 +728,24 @@ __global__ void moeLocalGatherDevice(MoeEpWorldInfo worldInfo, MoeExpertParallel
 
     int epSize = worldInfo.epSize;
     int rankTokenCount = recvRankCountCumSum[epSize - 1];
-    bool needLoad = laneInTile < expertParallelInfo.topK;
+    if (laneInTile >= expertParallelInfo.topK)
+    {
+        return;
+    }
 
     for (int index = tileId + blockIdx.x * tileCountPerBlock; index < localMaxTokenCount;
          index += tileCountPerBlock * gridDim.x)
     {
         int localTokenIndice = localGatherIndices[index];
-        int expertId = needLoad && (index < rankTokenCount)
+        int expertId = index < rankTokenCount
             ? gatheredExpertIds[localTokenIndice * expertParallelInfo.topK + laneInTile]
             : expertParallelInfo.expertCount;
-        float scale = needLoad && (index < rankTokenCount)
-            ? gatheredScales[localTokenIndice * expertParallelInfo.topK + laneInTile]
-            : 0.0f;
-        if (needLoad)
+        localExpertIds[index * expertParallelInfo.topK + laneInTile] = expertId;
+        if (gatheredScales)
         {
-            localExpertIds[index * expertParallelInfo.topK + laneInTile] = expertId;
+            float scale = index < rankTokenCount
+                ? gatheredScales[localTokenIndice * expertParallelInfo.topK + laneInTile]
+                : 0.0f;
             localScales[index * expertParallelInfo.topK + laneInTile] = scale;
         }
     }
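
For readers who want the semantics without the CUDA details, below is a minimal PyTorch sketch of what the updated moeLocalGatherDevice computes: expert ids are always written (with expertCount as the padding sentinel for slots past rankTokenCount), while scales are gathered only when the optional scale tensors are provided. The helper name and the per-token loop are ours and purely illustrative; the real implementation is the tiled kernel above, which now also returns early for lanes beyond topK instead of carrying a needLoad predicate through the loop.

import torch

def moe_local_gather_reference(recv_rank_cum_sum, local_gather_indices,
                               gathered_expert_ids, gathered_scales,
                               local_expert_ids, local_scales,
                               expert_count, top_k):
    # Tokens that actually arrived on this rank (last entry of the cumulative sum).
    rank_token_count = int(recv_rank_cum_sum[-1].item())
    for index, src_token in enumerate(local_gather_indices.tolist()):
        valid = index < rank_token_count
        # Expert ids are always written; padded slots get the sentinel value expert_count.
        local_expert_ids[index] = (gathered_expert_ids[src_token, :top_k]
                                   if valid else expert_count)
        # Scales are optional now: written only when the scale tensors were passed in.
        if gathered_scales is not None:
            local_scales[index] = (gathered_scales[src_token, :top_k]
                                   if valid else 0.0)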

cpp/tensorrt_llm/kernels/moePrepareKernels.cu

Lines changed: 17 additions & 13 deletions
@@ -461,7 +461,6 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
         {
             int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
             *((int4*) (experts)) = *(int4*) (sendExperts + tokenId * topK + groupId * UNIT_SIZE);
-            *((float4*) (scales)) = *(float4*) (sendScales + tokenId * topK + groupId * UNIT_SIZE);
 
 #pragma unroll
             for (int j = 0; j < UNIT_SIZE; j++)

@@ -470,15 +469,18 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
                 if (expertId / slotCountPerRank != targetRankId)
                 {
                     experts[j] = slotCount;
-                    scales[j] = 0.0f;
                 }
             }
 
             int* expertsPtr = (int*) (packPtr) + threadIdx.x * UNIT_SIZE;
-            float* scaleBasePtr = (float*) (packPtr + SCALE_OFFSET);
-            float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * UNIT_SIZE;
             *((int4*) (expertsPtr)) = *((int4*) (experts));
-            *((float4*) (scalesPtr)) = *((float4*) (scales));
+            if (sendScales != nullptr)
+            {
+                *((float4*) (scales)) = *(float4*) (sendScales + tokenId * topK + groupId * UNIT_SIZE);
+                float* scaleBasePtr = (float*) (packPtr + SCALE_OFFSET);
+                float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * UNIT_SIZE;
+                *((float4*) (scalesPtr)) = *((float4*) (scales));
+            }
         }
     }
     else if (localExpertStatics != nullptr)

@@ -518,18 +520,20 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
         {
             if (threadIdx.x < packetUnitCount)
             {
+                int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
                 int* expertsPtr = (int*) (packetPtr) + threadIdx.x * UNIT_SIZE;
-                float* scaleBasePtr = (float*) (packetPtr + SCALE_OFFSET);
-                float* scalesPtr = scaleBasePtr + threadIdx.x * UNIT_SIZE;
                 *((int4*) (experts)) = *((int4*) (expertsPtr));
-                *((float4*) (scales)) = *((float4*) (scalesPtr));
-
-                int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
-
                 int4* dstExpertsPtr = (int4*) (recvExperts + tokenId * topK + groupId * UNIT_SIZE);
-                float4* dstScalesPtr = (float4*) (recvScales + tokenId * topK + groupId * UNIT_SIZE);
                 *dstExpertsPtr = *((int4*) (experts));
-                *dstScalesPtr = *((float4*) (scales));
+
+                if (recvScales != nullptr)
+                {
+                    float* scaleBasePtr = (float*) (packetPtr + SCALE_OFFSET);
+                    float* scalesPtr = scaleBasePtr + threadIdx.x * UNIT_SIZE;
+                    *((float4*) (scales)) = *((float4*) (scalesPtr));
+                    float4* dstScalesPtr = (float4*) (recvScales + tokenId * topK + groupId * UNIT_SIZE);
+                    *dstScalesPtr = *((float4*) (scales));
+                }
             }
         }
         else if (localExpertStatics != nullptr)
cpp/tensorrt_llm/thop/moeCommOp.cpp

Lines changed: 49 additions & 34 deletions
@@ -87,15 +87,13 @@ moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional<torch
 }
 
 void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds,
-    torch::Tensor gatheredScales, torch::Tensor localExpertIds, torch::Tensor localScales, int64_t maxTokenCountPerRank,
-    int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
+    c10::optional<torch::Tensor> gatheredScales, torch::Tensor localExpertIds, c10::optional<torch::Tensor> localScales,
+    int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
 {
     CHECK_INPUT(recvRankCumSum, torch::kInt32);
     CHECK_INPUT(localGatherIndices, torch::kInt32);
     CHECK_INPUT(gatheredExpertIds, torch::kInt32);
-    CHECK_INPUT(gatheredScales, torch::kFloat32);
     CHECK_INPUT(localExpertIds, torch::kInt32);
-    CHECK_INPUT(localScales, torch::kFloat32);
 
     TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0");
     TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0");

@@ -107,17 +105,31 @@ void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherInd
     TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements");
     TORCH_CHECK(localGatherIndices.dim() == 1, "localGatherIndices must be a 1D tensor");
     TORCH_CHECK(gatheredExpertIds.dim() == 2, "gatheredExpertIds must be a 2D tensor");
-    TORCH_CHECK(gatheredScales.dim() == 2, "gatheredScales must be a 2D tensor");
     TORCH_CHECK(localExpertIds.dim() == 2, "localExpertIds must be a 2D tensor");
-    TORCH_CHECK(localScales.dim() == 2, "localScales must be a 2D tensor");
     TORCH_CHECK(gatheredExpertIds.size(1) == topK, "gatheredExpertIds must have topK columns");
-    TORCH_CHECK(gatheredScales.size(1) == topK, "gatheredScales must have topK columns");
     TORCH_CHECK(localExpertIds.size(1) == topK, "localExpertIds must have topK columns");
-    TORCH_CHECK(localScales.size(1) == topK, "localScales must have topK columns");
 
     int localMaxTokenCount = static_cast<int>(localGatherIndices.size(0));
     TORCH_CHECK(localExpertIds.size(0) == localMaxTokenCount, "localExpertIds must have localMaxTokenCount rows");
-    TORCH_CHECK(localScales.size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows");
+
+    TORCH_CHECK(gatheredScales.has_value() == localScales.has_value(),
+        "gatheredScales and localScales must be both valid or both invalid");
+    float const* gatheredScalesPtr = nullptr;
+    float* localScalesPtr = nullptr;
+    if (gatheredScales.has_value())
+    {
+        CHECK_INPUT(gatheredScales.value(), torch::kFloat32);
+        CHECK_INPUT(localScales.value(), torch::kFloat32);
+
+        TORCH_CHECK(gatheredScales->dim() == 2, "gatheredScales must be a 2D tensor");
+        TORCH_CHECK(gatheredScales->size(1) == topK, "gatheredScales must have topK columns");
+        TORCH_CHECK(localScales->dim() == 2, "localScales must be a 2D tensor");
+        TORCH_CHECK(localScales->size(1) == topK, "localScales must have topK columns");
+        TORCH_CHECK(localScales->size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows");
+
+        gatheredScalesPtr = gatheredScales->data_ptr<float>();
+        localScalesPtr = localScales->data_ptr<float>();
+    }
 
     auto stream = at::cuda::getCurrentCUDAStream();
 
@@ -128,7 +140,7 @@ void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherInd
     tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast<int>(epSize), static_cast<int>(epRank)};
     tensorrt_llm::kernels::moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
         recvRankCumSum.data_ptr<int>(), localGatherIndices.data_ptr<int>(), gatheredExpertIds.data_ptr<int>(),
-        gatheredScales.data_ptr<float>(), localExpertIds.data_ptr<int>(), localScales.data_ptr<float>(), stream);
+        gatheredScalesPtr, localExpertIds.data_ptr<int>(), localScalesPtr, stream);
 }
 
 void moeCommOp(torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output,

@@ -203,14 +215,13 @@ int64_t getPrepareWorkspaceSizePerRank(int64_t epSize)
     return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize(epSize32);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-    c10::optional<torch::Tensor>>
-moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch::Tensor> expertsStatics,
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+    torch::Tensor, c10::optional<torch::Tensor>>
+moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> scales, c10::optional<torch::Tensor> expertsStatics,
     torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount,
     int64_t slotCount, int64_t topK)
 {
     CHECK_INPUT(expertsIds, torch::kInt32);
-    CHECK_INPUT(scales, torch::kFloat32);
     TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4");
     TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4");
 
@@ -219,8 +230,6 @@ moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch
 
     torch::Tensor preparedLocalExpertIds
         = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kInt32));
-    torch::Tensor preparedLocalScales
-        = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32));
 
     torch::Tensor sendRankCountCumSum = torch::empty({epSize}, expertsIds.options().dtype(torch::kInt32));
     torch::Tensor RecvRankCountCumSum = torch::empty({epSize}, expertsIds.options().dtype(torch::kInt32));

@@ -240,6 +249,18 @@ moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch
     torch::Tensor sendRankIndices
         = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32));
 
+    c10::optional<torch::Tensor> preparedLocalScales;
+    float* scalesPtr = nullptr;
+    float* preparedLocalScalesPtr = nullptr;
+    if (scales.has_value())
+    {
+        CHECK_INPUT(scales.value(), torch::kFloat32);
+        scalesPtr = scales->data_ptr<float>();
+        preparedLocalScales
+            = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32));
+        preparedLocalScalesPtr = preparedLocalScales->data_ptr<float>();
+    }
+
     int* localExpertStaticsPtr = nullptr;
     int* gatheredExpertStaticsPtr = nullptr;
     c10::optional<torch::Tensor> gatheredExpertStatics;

@@ -271,10 +292,10 @@ moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch
         stream);
 
     tensorrt_llm::kernels::moe_prepare::allToAllMetadata(expertsIds.data_ptr<int>(),
-        preparedLocalExpertIds.data_ptr<int>(), scales.data_ptr<float>(), preparedLocalScales.data_ptr<float>(),
-        localExpertStaticsPtr, gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr<int>(),
-        sendRankIndices.data_ptr<int>(), RecvRankCountCumSum.data_ptr<int>(), recvRankIndices.data_ptr<int>(),
-        tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, epRank, epSize, stream);
+        preparedLocalExpertIds.data_ptr<int>(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr,
+        gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr<int>(), sendRankIndices.data_ptr<int>(),
+        RecvRankCountCumSum.data_ptr<int>(), recvRankIndices.data_ptr<int>(), tokenCount, maxTokenCountPerRank, topK,
+        expertCount, slotCount, epRank, epSize, stream);
 
     return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices,
         RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics);

@@ -287,8 +308,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int "
         "max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, "
-        "Tensor, "
-        "Tensor, Tensor)");
+        "Tensor, Tensor, Tensor)");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)

@@ -299,10 +319,9 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor "
-        "gathered_scales, Tensor local_expert_ids, Tensor local_scales, int max_token_count_per_rank, int "
-        "expert_count, int "
-        "top_k, int ep_rank, int ep_size) -> ()");
+        "moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor? "
+        "gathered_scales, Tensor local_expert_ids, Tensor? local_scales, int max_token_count_per_rank, int "
+        "expert_count, int top_k, int ep_rank, int ep_size) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)

@@ -314,8 +333,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor "
-        "recv_rank_cum_sum, "
-        "Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()");
+        "recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)

@@ -346,12 +364,9 @@ TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor scales, Tensor? experts_statics, "
-        "Tensor allWorkspace, int "
-        "max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int slot_count, int top_k) -> (Tensor, "
-        "Tensor, Tensor, "
-        "Tensor, "
-        "Tensor, Tensor, Tensor, Tensor?)");
+        "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? scales, Tensor? experts_statics, "
+        "Tensor allWorkspace, int max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int "
+        "slot_count, int top_k) -> (Tensor, Tensor?, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
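
With the schemas above now declaring the scale tensors as Tensor?, Python callers can pass None when no routing scales need to travel with the expert ids. The sketch below shows the argument layout for trtllm::moe_local_gather in that mode; it assumes the tensorrt_llm extension is installed on a CUDA machine and that importing it registers the trtllm ops, and the shapes and index values are placeholders chosen only to satisfy the checks in moeLocalGatherOp, not a realistic expert-parallel configuration.

import torch
import tensorrt_llm  # assumed to register the trtllm custom ops on import

ep_size, ep_rank = 2, 0
max_token_count_per_rank, top_k, expert_count = 4, 4, 8
local_max_token_count = 3  # rows of the local output tensors

# Placeholder routing metadata (illustrative values only).
recv_rank_cum_sum = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
local_gather_indices = torch.tensor([0, 1, 4], dtype=torch.int32, device="cuda")
gathered_expert_ids = torch.randint(
    0, expert_count, (max_token_count_per_rank * ep_size, top_k),
    dtype=torch.int32, device="cuda")
local_expert_ids = torch.empty(
    local_max_token_count, top_k, dtype=torch.int32, device="cuda")

# gathered_scales and local_scales are the two Tensor? arguments; both may be None.
torch.ops.trtllm.moe_local_gather(
    recv_rank_cum_sum, local_gather_indices, gathered_expert_ids,
    None,              # gathered_scales
    local_expert_ids,
    None,              # local_scales
    max_token_count_per_rank, expert_count, top_k, ep_rank, ep_size)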

tensorrt_llm/_mnnvl_utils.py

Lines changed: 11 additions & 5 deletions
@@ -406,9 +406,9 @@ def mnnvl_moe_expert_static_allgather(
     @staticmethod
     def mnnvl_moe_alltoallv_prepare(
         gathered_target_rank_ids: torch.Tensor,
-        real_rank_token_count_cumsum: torch.Tensor,
+        real_rank_token_count_cumsum: Optional[torch.Tensor],
         gathered_expert_ids: torch.Tensor,
-        gathered_scales: torch.Tensor,
+        gathered_scales: Optional[torch.Tensor],
         max_token_count_per_rank: int,
         expert_count: int,
         top_k: int,

@@ -437,9 +437,15 @@ def mnnvl_moe_alltoallv_prepare(
         local_expert_ids = torch.empty(
             local_token_allocation_count, top_k, dtype=torch.int32, device=torch.device("cuda")
         )
-        local_scales = torch.empty(
-            local_token_allocation_count, top_k, dtype=torch.float32, device=torch.device("cuda")
-        )
+        if gathered_scales is None:
+            local_scales = None
+        else:
+            local_scales = torch.empty(
+                local_token_allocation_count,
+                top_k,
+                dtype=torch.float32,
+                device=torch.device("cuda"),
+            )
 
         torch.ops.trtllm.moe_local_gather(
             recv_rank_count_cumsum,

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 3 additions & 3 deletions
@@ -182,7 +182,7 @@ def _(
 @torch.library.register_fake("trtllm::moe_comm_prepare_indices")
 def _(
     gathered_target_rank_ids: torch.Tensor,
-    real_rank_token_count_cum_sum,
+    real_rank_token_count_cum_sum: Optional[torch.Tensor],
     max_token_count_per_rank: int,
     expert_count: int,
     top_k: int,

@@ -220,9 +220,9 @@ def _(
     recv_rank_cum_sum: torch.Tensor,
     local_gather_indices: torch.Tensor,
     gathered_expert_ids: torch.Tensor,
-    gathered_scales: torch.Tensor,
+    gathered_scales: Optional[torch.Tensor],
     local_expert_ids: torch.Tensor,
-    local_scales: torch.Tensor,
+    local_scales: Optional[torch.Tensor],
     max_token_count_per_rank: int,
     expert_count: int,
     top_k: int,
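
The same optional-scales change surfaces in three notations across this commit: Tensor? in the schema strings registered in moeCommOp.cpp, c10::optional<torch::Tensor> in the C++ bindings, and Optional[torch.Tensor] in the fake registrations here. The toy op below (unrelated to TensorRT-LLM, assuming PyTorch 2.4+ for torch.library.custom_op) is a self-contained illustration of that correspondence and of how a fake kernel mirrors the real one for shape inference.

import torch
from typing import Optional

# "Tensor?" in a schema string corresponds to Optional[torch.Tensor] in Python
# and c10::optional<torch::Tensor> in C++.
@torch.library.custom_op("demo::scale_or_copy", mutates_args=())
def scale_or_copy(x: torch.Tensor, scales: Optional[torch.Tensor]) -> torch.Tensor:
    # Real kernel: scale when scales are given, otherwise just copy.
    return x.clone() if scales is None else x * scales

@scale_or_copy.register_fake
def _(x: torch.Tensor, scales: Optional[torch.Tensor]) -> torch.Tensor:
    # Fake kernel only reports the output shape/dtype, like the registrations above.
    return torch.empty_like(x)

y = scale_or_copy(torch.ones(4), None)                   # optional tensor omitted
z = scale_or_copy(torch.ones(4), torch.full((4,), 2.0))  # optional tensor provided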

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 1 addition & 1 deletion
@@ -125,6 +125,7 @@ def _check_configs(self):
 
         if self.apply_router_weight_on_input:
             assert self.routing_method.top_k == 1, "Current walkaround only supports top-1 routing"
+
         if self.quant_config and self.quant_config.quant_mode.has_any_quant(
                 exclude_kv_cache=True):
             if not (self.quant_config.quant_mode.has_nvfp4()

@@ -214,7 +215,6 @@ def forward_chunk(
         assert token_selected_experts.dtype == torch.int32
 
         if self.apply_router_weight_on_input:
-            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
             assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
             x = x * token_final_scales.to(x.dtype)
             # TODO: remove this once we have correct fusedmoe kernel ready
