@@ -87,15 +87,13 @@ moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional<torch
8787}
8888
8989void moeLocalGatherOp (torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds,
90- torch::Tensor gatheredScales, torch::Tensor localExpertIds, torch::Tensor localScales, int64_t maxTokenCountPerRank ,
91- int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
90+ c10::optional< torch::Tensor> gatheredScales, torch::Tensor localExpertIds, c10::optional< torch::Tensor> localScales,
91+ int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize)
9292{
9393 CHECK_INPUT (recvRankCumSum, torch::kInt32 );
9494 CHECK_INPUT (localGatherIndices, torch::kInt32 );
9595 CHECK_INPUT (gatheredExpertIds, torch::kInt32 );
96- CHECK_INPUT (gatheredScales, torch::kFloat32 );
9796 CHECK_INPUT (localExpertIds, torch::kInt32 );
98- CHECK_INPUT (localScales, torch::kFloat32 );
9997
10098 TORCH_CHECK (maxTokenCountPerRank > 0 , " maxTokenCountPerRank must be greater than 0" );
10199 TORCH_CHECK (expertCount > 0 , " expertCount must be greater than 0" );
@@ -107,17 +105,31 @@ void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherInd
107105 TORCH_CHECK (recvRankCumSum.size (0 ) == epSize, " recvRankCumSum must have epSize elements" );
108106 TORCH_CHECK (localGatherIndices.dim () == 1 , " localGatherIndices must be a 1D tensor" );
109107 TORCH_CHECK (gatheredExpertIds.dim () == 2 , " gatheredExpertIds must be a 2D tensor" );
110- TORCH_CHECK (gatheredScales.dim () == 2 , " gatheredScales must be a 2D tensor" );
111108 TORCH_CHECK (localExpertIds.dim () == 2 , " localExpertIds must be a 2D tensor" );
112- TORCH_CHECK (localScales.dim () == 2 , " localScales must be a 2D tensor" );
113109 TORCH_CHECK (gatheredExpertIds.size (1 ) == topK, " gatheredExpertIds must have topK columns" );
114- TORCH_CHECK (gatheredScales.size (1 ) == topK, " gatheredScales must have topK columns" );
115110 TORCH_CHECK (localExpertIds.size (1 ) == topK, " localExpertIds must have topK columns" );
116- TORCH_CHECK (localScales.size (1 ) == topK, " localScales must have topK columns" );
117111
118112 int localMaxTokenCount = static_cast <int >(localGatherIndices.size (0 ));
119113 TORCH_CHECK (localExpertIds.size (0 ) == localMaxTokenCount, " localExpertIds must have localMaxTokenCount rows" );
120- TORCH_CHECK (localScales.size (0 ) == localMaxTokenCount, " localScales must have localMaxTokenCount rows" );
114+
115+ TORCH_CHECK (gatheredScales.has_value () == localScales.has_value (),
116+ " gatheredScales and localScales must be both valid or both invalid" );
117+ float const * gatheredScalesPtr = nullptr ;
118+ float * localScalesPtr = nullptr ;
119+ if (gatheredScales.has_value ())
120+ {
121+ CHECK_INPUT (gatheredScales.value (), torch::kFloat32 );
122+ CHECK_INPUT (localScales.value (), torch::kFloat32 );
123+
124+ TORCH_CHECK (gatheredScales->dim () == 2 , " gatheredScales must be a 2D tensor" );
125+ TORCH_CHECK (gatheredScales->size (1 ) == topK, " gatheredScales must have topK columns" );
126+ TORCH_CHECK (localScales->dim () == 2 , " localScales must be a 2D tensor" );
127+ TORCH_CHECK (localScales->size (1 ) == topK, " localScales must have topK columns" );
128+ TORCH_CHECK (localScales->size (0 ) == localMaxTokenCount, " localScales must have localMaxTokenCount rows" );
129+
130+ gatheredScalesPtr = gatheredScales->data_ptr <float >();
131+ localScalesPtr = localScales->data_ptr <float >();
132+ }
121133
122134 auto stream = at::cuda::getCurrentCUDAStream ();
123135
@@ -128,7 +140,7 @@ void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherInd
128140 tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast <int >(epSize), static_cast <int >(epRank)};
129141 tensorrt_llm::kernels::moeLocalGather (worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
130142 recvRankCumSum.data_ptr <int >(), localGatherIndices.data_ptr <int >(), gatheredExpertIds.data_ptr <int >(),
131- gatheredScales. data_ptr < float >() , localExpertIds.data_ptr <int >(), localScales. data_ptr < float >() , stream);
143+ gatheredScalesPtr , localExpertIds.data_ptr <int >(), localScalesPtr , stream);
132144}
133145
134146void moeCommOp (torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output,
@@ -203,14 +215,13 @@ int64_t getPrepareWorkspaceSizePerRank(int64_t epSize)
203215 return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize (epSize32);
204216}
205217
206- std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
207- c10::optional<torch::Tensor>>
208- moePrepareOp (torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch::Tensor> expertsStatics,
218+ std::tuple<torch::Tensor, c10::optional< torch::Tensor> , torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
219+ torch::Tensor, c10::optional<torch::Tensor>>
220+ moePrepareOp (torch::Tensor expertsIds, c10::optional< torch::Tensor> scales, c10::optional<torch::Tensor> expertsStatics,
209221 torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount,
210222 int64_t slotCount, int64_t topK)
211223{
212224 CHECK_INPUT (expertsIds, torch::kInt32 );
213- CHECK_INPUT (scales, torch::kFloat32 );
214225 TORCH_CHECK (expertCount % 4 == 0 , " expertCount must be divisible by 4" );
215226 TORCH_CHECK (slotCount % 4 == 0 , " slotCount must be divisible by 4" );
216227
@@ -219,8 +230,6 @@ moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch
219230
220231 torch::Tensor preparedLocalExpertIds
221232 = torch::empty ({maxTokenCountPerRank * epSize, topK}, expertsIds.options ().dtype (torch::kInt32 ));
222- torch::Tensor preparedLocalScales
223- = torch::empty ({maxTokenCountPerRank * epSize, topK}, expertsIds.options ().dtype (torch::kFloat32 ));
224233
225234 torch::Tensor sendRankCountCumSum = torch::empty ({epSize}, expertsIds.options ().dtype (torch::kInt32 ));
226235 torch::Tensor RecvRankCountCumSum = torch::empty ({epSize}, expertsIds.options ().dtype (torch::kInt32 ));
@@ -240,6 +249,18 @@ moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch
240249 torch::Tensor sendRankIndices
241250 = torch::empty ({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options ().dtype (torch::kInt32 ));
242251
252+ c10::optional<torch::Tensor> preparedLocalScales;
253+ float * scalesPtr = nullptr ;
254+ float * preparedLocalScalesPtr = nullptr ;
255+ if (scales.has_value ())
256+ {
257+ CHECK_INPUT (scales.value (), torch::kFloat32 );
258+ scalesPtr = scales->data_ptr <float >();
259+ preparedLocalScales
260+ = torch::empty ({maxTokenCountPerRank * epSize, topK}, expertsIds.options ().dtype (torch::kFloat32 ));
261+ preparedLocalScalesPtr = preparedLocalScales->data_ptr <float >();
262+ }
263+
243264 int * localExpertStaticsPtr = nullptr ;
244265 int * gatheredExpertStaticsPtr = nullptr ;
245266 c10::optional<torch::Tensor> gatheredExpertStatics;
@@ -271,10 +292,10 @@ moePrepareOp(torch::Tensor expertsIds, torch::Tensor scales, c10::optional<torch
271292 stream);
272293
273294 tensorrt_llm::kernels::moe_prepare::allToAllMetadata (expertsIds.data_ptr <int >(),
274- preparedLocalExpertIds.data_ptr <int >(), scales. data_ptr < float >(), preparedLocalScales. data_ptr < float >() ,
275- localExpertStaticsPtr, gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr <int >(),
276- sendRankIndices. data_ptr < int >(), RecvRankCountCumSum.data_ptr <int >(), recvRankIndices.data_ptr <int >(),
277- tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, epRank, epSize, stream);
295+ preparedLocalExpertIds.data_ptr <int >(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr ,
296+ gatheredExpertStaticsPtr, workspace, sendRankCountCumSum. data_ptr < int >(), sendRankIndices .data_ptr <int >(),
297+ RecvRankCountCumSum.data_ptr <int >(), recvRankIndices.data_ptr <int >(), tokenCount, maxTokenCountPerRank, topK ,
298+ expertCount, slotCount, epRank, epSize, stream);
278299
279300 return std::make_tuple (preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices,
280301 RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics);
@@ -287,8 +308,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
287308 m.def (
288309 " moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int "
289310 " max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, "
290- " Tensor, "
291- " Tensor, Tensor)" );
311+ " Tensor, Tensor, Tensor)" );
292312}
293313
294314TORCH_LIBRARY_IMPL (trtllm, CUDA, m)
@@ -299,10 +319,9 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
299319TORCH_LIBRARY_FRAGMENT (trtllm, m)
300320{
301321 m.def (
302- " moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor "
303- " gathered_scales, Tensor local_expert_ids, Tensor local_scales, int max_token_count_per_rank, int "
304- " expert_count, int "
305- " top_k, int ep_rank, int ep_size) -> ()" );
322+ " moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor? "
323+ " gathered_scales, Tensor local_expert_ids, Tensor? local_scales, int max_token_count_per_rank, int "
324+ " expert_count, int top_k, int ep_rank, int ep_size) -> ()" );
306325}
307326
308327TORCH_LIBRARY_IMPL (trtllm, CUDA, m)
@@ -314,8 +333,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
314333{
315334 m.def (
316335 " moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor "
317- " recv_rank_cum_sum, "
318- " Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()" );
336+ " recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()" );
319337}
320338
321339TORCH_LIBRARY_IMPL (trtllm, CUDA, m)
@@ -346,12 +364,9 @@ TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m)
346364TORCH_LIBRARY_FRAGMENT (trtllm, m)
347365{
348366 m.def (
349- " mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor scales, Tensor? experts_statics, "
350- " Tensor allWorkspace, int "
351- " max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int slot_count, int top_k) -> (Tensor, "
352- " Tensor, Tensor, "
353- " Tensor, "
354- " Tensor, Tensor, Tensor, Tensor?)" );
367+ " mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? scales, Tensor? experts_statics, "
368+ " Tensor allWorkspace, int max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int "
369+ " slot_count, int top_k) -> (Tensor, Tensor?, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)" );
355370}
356371
357372TORCH_LIBRARY_IMPL (trtllm, CUDA, m)
0 commit comments