diff --git a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceCommon.h b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceCommon.h index f99bdb3dadc..d3e8063a048 100644 --- a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceCommon.h +++ b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceCommon.h @@ -63,11 +63,12 @@ struct MoeLoadBalanceStatisticInfo // rawDataWindowSize means the size of the raw data window. // e.g. how many steps of raw data are kept in the memory. - int rawDataWindowSize = 1; + // currently we keep only the data of the current iteration; previous iterations are accumulated into expertLoadFactor. + static constexpr int rawDataWindowSize = 1; // decayFactor means the decay factor of the raw data per step. - // e.g. if decayFactor is 0.9, then the raw data of expert i will be decayed by 0.9 for each step. - float decayFactor = 0.9f; + // e.g. if decayFactor is 0.95, then the raw data of expert i will be decayed by 0.95 for each step. + float decayFactor = 0.95f; }; // The placement information for GPU diff --git a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu index 4f4bce83ec3..6f67d45ed09 100644 --- a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu +++ b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu @@ -128,6 +128,19 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal) signal->stepAndOwner += MoeLoadBalanceSingleLayerSignal::kCPU; } +template <typename TYPE> +__global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount) +{ + if (*enabled == 0) + { + return; + } + TYPE oldExpertTokenCount = {0}; + int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x; + TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr); + typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount; +} + template <typename TYPE> __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount) { @@ -151,8 +164,8 @@ __global__ void shiftWindowKernel(MoeLoadBalanceMetaInfo metaInfo, int* const en typedExpertTokenCountPtr[threadIdx.x] = oldExpertTokenCount; } -__global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, - int totalEltCount, int* const enabled, int* const gatheredRawExpertIds) +__global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, int* expertTokenCount, int totalEltCount, + int* const enabled, int* const gatheredRawExpertIds) { extern __shared__ int sharedExpertCount[]; if (*enabled == 0) { @@ -175,19 +188,19 @@ __global__ void statisticKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceS __syncthreads(); for (int i = threadIdx.x; i < metaInfo.expertCount; i += blockDim.x) { - atomicAdd_system(&statisticInfo.expertTokenCount[i], sharedExpertCount[i]); + atomicAdd_system(&expertTokenCount[i], sharedExpertCount[i]); } } -__global__ void updateLoadFactorKernel( - MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, int* const enabled) +__global__ void updateLoadFactorKernel(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, + int* expertTokenCountPtr, int* const enabled) { if (*enabled == 0) { return; } int expertIdx = blockIdx.x * blockDim.x + threadIdx.x; - int expertTokenCount = statisticInfo.expertTokenCount[expertIdx]; + int expertTokenCount = expertTokenCountPtr[expertIdx]; float* loadFactor = statisticInfo.expertLoadFactor;
loadFactor[expertIdx] = loadFactor[expertIdx] * statisticInfo.decayFactor + expertTokenCount; } @@ -233,7 +246,7 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic } int sharedMemorySize = metaInfo.expertCount * sizeof(int); statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>( - metaInfo, statisticInfo, totalEltCount, enabled, gatheredRawExpertIds); + metaInfo, statisticInfo.expertTokenCount, totalEltCount, enabled, gatheredRawExpertIds); } if (isLastStage) { // only last stage need update load factor. int threadCount = 128; int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount; - updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>(metaInfo, statisticInfo, enabled); + updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>( + metaInfo, statisticInfo, statisticInfo.expertTokenCount, enabled); + } +} + +void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int numTotalTokens, + int* localExpertTokenCount, int* const enabled, bool isFirstStage, bool isLastStage, int* const localRawExpertIds, + cudaStream_t stream) +{ + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + if (isFirstStage) + { + // zero localExpertTokenCount for the new iteration + // only the first stage needs to do this. + int threadCount = metaInfo.expertCount; + auto* kernelFunc = zeroExpertTokenCountKernel<int>; + if (threadCount % 4 == 0) + { + threadCount /= 4; + kernelFunc = zeroExpertTokenCountKernel<int4>; + } + else if (threadCount % 2 == 0) + { + threadCount /= 2; + kernelFunc = zeroExpertTokenCountKernel<int2>; + } + dim3 gridDim(1); + dim3 blockDim(threadCount); + void* args[] + = {&metaInfo, static_cast<void*>(const_cast<int**>(&enabled)), static_cast<void*>(&localExpertTokenCount)}; + TLLM_CHECK_WITH_INFO( + threadCount <= 1024, "expertCount=%d is too large and not supported now.", metaInfo.expertCount); + TLLM_CUDA_CHECK(cudaLaunchKernel(kernelFunc, gridDim, blockDim, &args[0], 0, stream)); } + + { + // accumulate the per-expert token counts of the local tokens into localExpertTokenCount + int threadCount = 1024; + int totalEltCount = numTotalTokens * metaInfo.topK; + int blockCount = (totalEltCount + threadCount - 1) / threadCount; + if (blockCount > smCount) + { + blockCount = smCount; + } + int sharedMemorySize = metaInfo.expertCount * sizeof(int); + statisticKernel<<<blockCount, threadCount, sharedMemorySize, stream>>>( + metaInfo, localExpertTokenCount, totalEltCount, enabled, localRawExpertIds); + } +} + +void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, + int* globalExpertTokenCount, int* const enabled, cudaStream_t stream) +{ + int threadCount = 128; + int blockCount = (metaInfo.expertCount + threadCount - 1) / threadCount; + updateLoadFactorKernel<<<blockCount, threadCount, 0, stream>>>( + metaInfo, statisticInfo, globalExpertTokenCount, enabled); } template diff --git a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h index ae78f4accdd..85acd1fb682 100644 --- a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h +++ b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h @@ -70,6 +70,32 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal); void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, int numTotalTokens, int* const enabled, bool isFirstStage, bool isLastStage, int* const gatheredRawExpertIds, cudaStream_t stream); +// @brief perform the statistics based on the local device's data +// +// This function is used
to launch a kernel that computes the per-expert statistics for the local tokens. +// +// @param metaInfo: the meta info +// @param numTotalTokens: the total number of tokens in localRawExpertIds +// @param localExpertTokenCount: the token count that each expert has for local tokens. +// @param enabled: flag on device memory to indicate if the statistic is enabled +// @param isFirstStage: whether the current stage is the first stage (only the first stage zeroes localExpertTokenCount) +// @param isLastStage: whether the current stage is the last stage +// @param localRawExpertIds: the local raw expert ids, should have shape [numTotalTokens, metaInfo.topK] +void moeHierarchicalStatisticLocalDevice(MoeLoadBalanceMetaInfo metaInfo, int numTotalTokens, + int* localExpertTokenCount, int* const enabled, bool isFirstStage, bool isLastStage, int* const localRawExpertIds, + cudaStream_t stream); + +// @brief update the statistic info based on the globally reduced expert token counts +// +// This function is used to launch a kernel to update the statistic info once per iteration. +// +// @param metaInfo: the meta info +// @param statisticInfo: the statistic info +// @param globalExpertTokenCount: the global expert token count, should have shape [metaInfo.expertCount] +// @param enabled: flag on device memory to indicate if the statistic is enabled +void moeHierarchicalStatisticUpdate(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatisticInfo statisticInfo, + int* globalExpertTokenCount, int* const enabled, cudaStream_t stream); + // @brief compute the route // // This function is used to launch a kernel to compute the route based on the token selected experts and the placement diff --git a/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp b/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp index d2df16424e3..e694249105b 100644 --- a/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp +++ b/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp @@ -85,6 +85,60 @@ void moeLoadBalanceStatistic(torch::Tensor gatheredRawExpertIds, torch::Tensor e static_cast<bool>(isFirstStage), static_cast<bool>(isLastStage), gatheredRawExpertIds.data_ptr<int>(), stream); } +void moeHierarchicalStatisticLocalDevice(torch::Tensor localRawExpertIds, torch::Tensor localExpertTokenCount, + torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr, int64_t isFirstStage, int64_t isLastStage) +{ + CHECK_INPUT(localRawExpertIds, torch::kInt32); + CHECK_INPUT(localExpertTokenCount, torch::kInt32); + CHECK_INPUT(enabled, torch::kInt32); + TORCH_CHECK(localRawExpertIds.dim() == 2, "localRawExpertIds must be a 2D tensor"); + TORCH_CHECK(localExpertTokenCount.dim() == 1, "localExpertTokenCount must be a 1D tensor"); + int topK = localRawExpertIds.size(1); + TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor"); + TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element"); + TORCH_CHECK(isFirstStage == 0 || isFirstStage == 1, "isFirstStage must be 0 or 1"); + TORCH_CHECK(isLastStage == 0 || isLastStage == 1, "isLastStage must be 0 or 1"); + TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null"); + + auto* loadBalancer + = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr); + auto stream = at::cuda::getCurrentCUDAStream(); + + tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo(); + + TORCH_CHECK(localExpertTokenCount.size(0) == metaInfo.expertCount, "localExpertTokenCount should have shape (%d,)", + metaInfo.expertCount); + TORCH_CHECK(topK == metaInfo.topK, "topK must be equal to metaInfo.topK"); + + int numTotalTokens = 
localRawExpertIds.size(0); + + tensorrt_llm::kernels::moeHierarchicalStatisticLocalDevice(metaInfo, numTotalTokens, + localExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), static_cast<bool>(isFirstStage), + static_cast<bool>(isLastStage), localRawExpertIds.data_ptr<int>(), stream); +} + +void moeHierarchicalStatisticUpdate( + torch::Tensor globalExpertTokenCount, torch::Tensor enabled, int64_t singleLayerLoadBalancerPtr) +{ + CHECK_INPUT(globalExpertTokenCount, torch::kInt32); + CHECK_INPUT(enabled, torch::kInt32); + TORCH_CHECK(globalExpertTokenCount.dim() == 1, "globalExpertTokenCount must be a 1D tensor"); + TORCH_CHECK(enabled.dim() == 1, "enabled must be a 1D tensor"); + TORCH_CHECK(enabled.size(0) == 1, "enabled must have 1 element"); + TORCH_CHECK(singleLayerLoadBalancerPtr != 0, "singleLayerLoadBalancerPtr must be non-null"); + auto* loadBalancer + = reinterpret_cast<tensorrt_llm::runtime::SingleLayerMoeLoadBalancer*>(singleLayerLoadBalancerPtr); + auto stream = at::cuda::getCurrentCUDAStream(); + + tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo = loadBalancer->getMetaInfo(); + auto statisticInfo = loadBalancer->getStatisticInfo(); + + TORCH_CHECK(globalExpertTokenCount.size(0) == metaInfo.expertCount, + "globalExpertTokenCount should have shape (%d,)", metaInfo.expertCount); + tensorrt_llm::kernels::moeHierarchicalStatisticUpdate( + metaInfo, *statisticInfo, globalExpertTokenCount.data_ptr<int>(), enabled.data_ptr<int>(), stream); +} + torch::Tensor moeLoadBalanceRouting( torch::Tensor tokenSelectedExperts, bool offsetByEpRank, int64_t singleLayerLoadBalancerPtr) { @@ -182,6 +236,31 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("moe_load_balance_statistic", &torch_ext::moeLoadBalanceStatistic); } +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "moe_hierarchical_statistic_local_device(Tensor local_raw_expert_ids, Tensor local_expert_token_count, Tensor " + "enabled, int " + "single_layer_load_balancer_ptr, int is_first_stage, int is_last_stage) -> ()"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("moe_hierarchical_statistic_local_device", &torch_ext::moeHierarchicalStatisticLocalDevice); +} + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "moe_hierarchical_statistic_update(Tensor global_expert_token_count, Tensor enabled, int " + "single_layer_load_balancer_ptr) -> ()"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("moe_hierarchical_statistic_update", &torch_ext::moeHierarchicalStatisticUpdate); +} + TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( diff --git a/cpp/tests/kernels/moeLoadBalanceKernelTest.cpp b/cpp/tests/kernels/moeLoadBalanceKernelTest.cpp index 0241645dff3..6844fb69c71 100644 --- a/cpp/tests/kernels/moeLoadBalanceKernelTest.cpp +++ b/cpp/tests/kernels/moeLoadBalanceKernelTest.cpp @@ -110,7 +110,6 @@ struct MoeLoadBalanceTestParam bool isFirstStage; bool isLastStage; float decayFactor; - int rawDataWindowSize; }; class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam<MoeLoadBalanceTestParam> @@ -126,14 +125,13 @@ class MoeLoadBalanceStatisticKernelTest : public ::testing::TestWithParam= 0; --windowIdx) + for (int windowIdx = mStatisticInfo.rawDataWindowSize - 1; windowIdx >= 0; --windowIdx) { if (windowIdx > 0) { @@ -305,7 +303,7 @@ TEST_P(MoeLoadBalanceStatisticKernelTest, TestStatistics) EXPECT_NEAR(mHostExpertLoadFactor[i], mExpectedLoadFactor[i], 1e-6) << "Expert " << i << " load factor mismatch"; } - for (int i = 0; i < param.expertCount * param.rawDataWindowSize; ++i) + for (int i = 0; i < param.expertCount * mStatisticInfo.rawDataWindowSize; ++i) { EXPECT_EQ(mHostExpertTokenCount[i], mExpectedExpertTokenCount[i]) << 
"Expert " << i << " token count mismatch"; } @@ -323,8 +321,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceStatisticKernelTests, MoeLoadBalanceStati /* maxTokenCountPerRank */ 128, /* isFirstStage */ true, /* isLastStage */ true, - /* decayFactor */ 0.9f, - /* rawDataWindowSize */ 3}, + /* decayFactor */ 0.9f}, // large scale test scenarios MoeLoadBalanceTestParam{/* expertCount */ 64, /* topK */ 4, @@ -334,8 +331,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceStatisticKernelTests, MoeLoadBalanceStati /* maxTokenCountPerRank */ 512, /* isFirstStage */ false, /* isLastStage */ true, - /* decayFactor */ 0.95f, - /* rawDataWindowSize */ 5} // can add more test scenarios + /* decayFactor */ 0.95f} // can add more test scenarios )); class MoeLoadBalanceRouteKernelTest : public ::testing::TestWithParam @@ -601,8 +597,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern /* maxTokenCountPerRank */ 128, /* isFirstStage */ true, /* isLastStage */ true, - /* decayFactor */ 0.9f, - /* rawDataWindowSize */ 3}, + /* decayFactor */ 0.9f}, // large scale test scenarios MoeLoadBalanceTestParam{/* expertCount */ 256, /* topK */ 8, @@ -612,8 +607,7 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern /* maxTokenCountPerRank */ 5000, /* isFirstStage */ false, /* isLastStage */ true, - /* decayFactor */ 0.95f, - /* rawDataWindowSize */ 5}, + /* decayFactor */ 0.95f}, // edge case: single rank MoeLoadBalanceTestParam{/* expertCount */ 16, /* topK */ 2, @@ -623,5 +617,4 @@ INSTANTIATE_TEST_SUITE_P(MoeLoadBalanceRouteKernelTests, MoeLoadBalanceRouteKern /* maxTokenCountPerRank */ 64, /* isFirstStage */ true, /* isLastStage */ true, - /* decayFactor */ 0.9f, - /* rawDataWindowSize */ 1})); + /* decayFactor */ 0.9f})); diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 05489a71112..292605c55ad 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -263,9 +263,22 @@ def _(single_layer_load_balancer_ptr: int): pass @torch.library.register_fake("trtllm::moe_load_balance_statistic") - def _(single_layer_load_balancer_ptr: int, - gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor, - is_first_stage: bool, is_last_stage: bool): + def _(gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor, + single_layer_load_balancer_ptr: int, is_first_stage: bool, + is_last_stage: bool): + pass + + @torch.library.register_fake( + "trtllm::moe_hierarchical_statistic_local_device") + def _(local_raw_expert_ids: torch.Tensor, + local_expert_token_count: torch.Tensor, enabled: torch.Tensor, + single_layer_load_balancer_ptr: int, is_first_stage: bool, + is_last_stage: bool): + pass + + @torch.library.register_fake("trtllm::moe_hierarchical_statistic_update") + def _(global_expert_token_count: torch.Tensor, enabled: torch.Tensor, + single_layer_load_balancer_ptr: int): pass @torch.library.register_fake("trtllm::moe_load_balance_routing") diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 7cbe3a33628..2d0c4c00c8b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -64,7 +64,7 @@ def create_moe( moe_load_balancer = get_moe_load_balancer() if moe_load_balancer is not None: - assert moe_cls == CutlassFusedMoE, "MoE Load Balance is only supported in CutlassFusedMoE now." 
+ assert moe_cls == WideEPMoE, "MoE Load Balance is only supported in WideEPMoE now." if moe_cls == TRTLLMGenFusedMoE: assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE." diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 0d4e99d50e0..3506845133b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -359,32 +359,39 @@ def forward_chunk( ) and is_first_call: self.layer_load_balancer.maybe_cudagraph_done_wait() - need_statistic = False + loadbalancer_local_statistic_info = None + gathered_loadbalancer_local_statistic_info = None if self.layer_load_balancer is None: token_selected_slots = token_selected_experts else: + if not self.layer_load_balancer.is_static_routing(): + self.layer_load_balancer.local_statistic( + token_selected_experts, + is_first_stage=is_first_call, + is_last_stage=is_last_call) token_selected_slots = self.layer_load_balancer.route( token_selected_experts, self.use_dp) if not self.layer_load_balancer.is_static_routing(): - need_statistic = True + # split into two part to get possible overlap with load balancer routing + if is_last_call: + loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( + ) # If load balancer is disabled, the statistics are collected from expert IDs. # If load balancer is enabled, the statistics are collected from expert slot IDs. ExpertStatistic.set_layer(self.layer_idx) ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) - token_selected_experts_for_statistic = token_selected_experts if need_statistic else None - if self.enable_alltoall: if self.alltoall_method_type == AlltoallMethodType.MNNVL: token_count = x.shape[0] alltoall_info = None - x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \ + x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \ self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens, x, token_selected_slots, token_final_scales, - token_selected_experts_for_statistic) + loadbalancer_local_statistic_info) elif self.alltoall_method_type == AlltoallMethodType.DeepEP: if not self.use_postquant_alltoall: x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \ @@ -427,6 +434,7 @@ def forward_chunk( ) x_sf = None + sf_swizzle = True if self.has_any_quant: if self.has_fp8_qdq: x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( @@ -439,11 +447,14 @@ def forward_chunk( # note: we use uint8 to store 2 fp4 values x_col = x.shape[1] * 2 else: + sf_swizzle = not self.use_postquant_alltoall x_row = x.shape[0] x_col = x.shape[1] x, x_sf = torch.ops.trtllm.fp4_quantize( x, self.fc31_input_scale, self.scaling_vector_size, - False) + False, sf_swizzle) + if self.use_postquant_alltoall: + x_sf = x_sf.view((x_row, -1)) elif self.has_deepseek_fp8_block_scales: use_deepseek_fp8_block_scale = True @@ -457,24 +468,31 @@ def forward_chunk( if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather( ) and not self.enable_alltoall: - x, x_sf, token_selected_slots, token_final_scales, token_selected_experts_for_statistic = allgather( + x, x_sf, token_selected_slots, token_final_scales = allgather( [ - x, x_sf, token_selected_slots, token_final_scales, - token_selected_experts_for_statistic + x, + x_sf, + 
token_selected_slots, + token_final_scales, ], self.mapping, dim=0, sizes=None if use_dp_padding else all_rank_num_tokens) + # use separate allgather since doesn't have sizes, can be optimized but in allgather path it is OK + if is_last_call: + gathered_loadbalancer_local_statistic_info = allgather( + loadbalancer_local_statistic_info, self.mapping, dim=0) # Fp4 gemm has extra scaling factor if x_sf is not None: x_sf = reswizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( - ): - self.layer_load_balancer.statistic( - token_selected_experts_for_statistic, is_first_call, - is_last_call) + ) and is_last_call: + gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( + (self.mapping.moe_ep_size, self.num_experts)) + self.layer_load_balancer.update_statistic( + gathered_loadbalancer_local_statistic_info) if self.smart_router and not cutlass_min_latency_mode: ep_size = self.cluster_size @@ -501,10 +519,15 @@ def forward_chunk( if self.use_postquant_alltoall: if self.alltoall_method_type == AlltoallMethodType.MNNVL: x, x_sf = self.alltoall_postquant_dispatch( - x, x_sf, x_row, x_col, alltoall_info) + x, + x_sf, + x_row, + x_col, + alltoall_info, + is_sf_swizzle=sf_swizzle) elif self.alltoall_method_type == AlltoallMethodType.DeepEP: if x_sf is not None: - if self.has_nvfp4: + if self.has_nvfp4 and sf_swizzle: x_sf = unswizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) # Adapter between `x_sf` and DeepEP @@ -760,35 +783,23 @@ def alltoall_prepare_maybe_dispatch( self, all_rank_num_tokens: list, x: torch.Tensor, token_selected_slots: torch.Tensor, token_final_scales: torch.Tensor, - token_selected_experts_for_statistic: Optional[torch.Tensor]): + local_statistic_tensor: Optional[torch.Tensor]): top_k = self.routing_method.experts_per_token # gather router info max_num_token = max(all_rank_num_tokens) - token_selected_slots = torch.nn.functional.pad( - token_selected_slots, - (0, 0, 0, max_num_token - token_selected_slots.shape[0]), - 'constant', self.num_slots) - token_selected_experts_for_statistic = torch.nn.functional.pad( - token_selected_experts_for_statistic, - (0, 0, 0, - max_num_token - token_selected_experts_for_statistic.shape[0]), - 'constant', self.num_experts - ) if token_selected_experts_for_statistic is not None else None - token_final_scales = torch.nn.functional.pad( - token_final_scales, - (0, 0, 0, max_num_token - token_final_scales.shape[0])) - gathered_token_selected_slots, gathered_token_final_scales, gathered_token_selected_experts_for_statistic = allgather( - [ - token_selected_slots, token_final_scales, - token_selected_experts_for_statistic - ], + if max_num_token > token_selected_slots.shape[0]: + token_selected_slots = torch.nn.functional.pad( + token_selected_slots, + (0, 0, 0, max_num_token - token_selected_slots.shape[0]), + 'constant', self.num_slots) + if max_num_token > token_final_scales.shape[0]: + token_final_scales = torch.nn.functional.pad( + token_final_scales, + (0, 0, 0, max_num_token - token_final_scales.shape[0])) + gathered_token_selected_slots, gathered_token_final_scales, gathered_local_statistic_tensor = allgather( + [token_selected_slots, token_final_scales, local_statistic_tensor], self.mapping, dim=0) - if gathered_token_selected_experts_for_statistic is not None: - gathered_token_selected_experts_for_statistic = torch.flatten( - gathered_token_selected_experts_for_statistic.contiguous(), - start_dim=0, - end_dim=-2) 
gathered_token_selected_slots = torch.flatten( gathered_token_selected_slots.contiguous(), start_dim=0, end_dim=-2) gathered_token_final_scales = torch.flatten( @@ -808,17 +819,21 @@ def alltoall_prepare_maybe_dispatch( self.alltoall_workspace, self.ep_rank, self.ep_size) - return x, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic, alltoall_info + return x, token_selected_slots, token_final_scales, gathered_local_statistic_tensor, alltoall_info - def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor, - x_row: int, x_col: int, - alltoall_info: MoEAlltoallInfo): + def alltoall_postquant_dispatch(self, + x: torch.Tensor, + x_sf: torch.Tensor, + x_row: int, + x_col: int, + alltoall_info: MoEAlltoallInfo, + is_sf_swizzle: bool = True): x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, self.alltoall_workspace, self.ep_rank, self.ep_size) if x_sf is not None: - if self.has_nvfp4: + if self.has_nvfp4 and is_sf_swizzle: x_sf = unswizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py index 67dbd56eabf..b611f3f97a5 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py @@ -302,10 +302,14 @@ def __init__( self.load_expert_ids = list(range(load_expert_start, load_expert_end)) self.statistic_flag_tensor = None + self.local_statistic_tensor = None self.cudagraph_stream = None self.cudagraph_event = None + self.statistic_stream = None + self.statistic_event = None + def get_layer_idx(self): return self.single_layer_load_balancer_impl.get_layer_id() @@ -474,6 +478,12 @@ def set_cpu_stage(self): self.statistic_flag_tensor = None if is_graph_capturing(): assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage." + assert self.statistic_event is not None + assert self.statistic_stream is not None + # wait statistic update done + self.statistic_event.wait() + self.statistic_event = None + self.statistic_stream = None current_stream_event = torch.cuda.Event() current_stream_event.record(torch.cuda.current_stream()) with torch.cuda.stream(self.cudagraph_stream): @@ -511,6 +521,87 @@ def statistic(self, gathered_raw_expert_ids: torch.Tensor, self.single_layer_load_balancer_ptr, is_first_stage, is_last_stage) + def local_statistic(self, local_raw_expert_ids: torch.Tensor, + is_first_stage: bool, is_last_stage: bool): + """ + Perform local statistics on the expert IDs. 
+ + Args: + local_raw_expert_ids: The raw expert IDs selected by this rank's local tokens + is_first_stage: Whether this is the first stage + is_last_stage: Whether this is the last stage + """ + if self.updates_enabled: + assert isinstance(self.statistic_flag_tensor, torch.Tensor) + if is_first_stage: + assert self.local_statistic_tensor is None + self.local_statistic_tensor = torch.empty( + (self.expert_count, ), + dtype=torch.int32, + device=torch.device('cuda')) + if is_graph_capturing(): + self.statistic_event = torch.cuda.Event() + self.statistic_stream = torch.cuda.Stream() + current_stream_event = torch.cuda.Event() + current_stream_event.record(torch.cuda.current_stream()) + with torch.cuda.stream(self.statistic_stream): + current_stream_event.wait() + torch.ops.trtllm.moe_hierarchical_statistic_local_device( + local_raw_expert_ids, self.local_statistic_tensor, + self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr, is_first_stage, + is_last_stage) + self.statistic_event.record(self.statistic_stream) + else: + torch.ops.trtllm.moe_hierarchical_statistic_local_device( + local_raw_expert_ids, self.local_statistic_tensor, + self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr, is_first_stage, + is_last_stage) + + def get_local_statistic_tensor(self): + """ + Get the local statistic tensor. The caller should allgather it across EP ranks and then call update_statistic. + Returns: + The local statistic tensor if statistic updates are enabled, otherwise None + """ + if self.updates_enabled: + assert self.local_statistic_tensor is not None + if is_graph_capturing(): + assert self.statistic_event is not None + assert self.statistic_stream is not None + self.statistic_event.wait() + return self.local_statistic_tensor + return None + + def update_statistic(self, gathered_local_statistic_tensor: torch.Tensor): + """ + Perform the update with the global statistics. + + Args: + gathered_local_statistic_tensor: gathered local statistic info, should have shape (world_size, self.expert_count) + """ + if self.updates_enabled: + assert isinstance(self.statistic_flag_tensor, torch.Tensor) + + def _update_statistic(): + global_statistic_info = torch.sum( + gathered_local_statistic_tensor, dim=0, dtype=torch.int32) + torch.ops.trtllm.moe_hierarchical_statistic_update( + global_statistic_info, self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr) + + if is_graph_capturing(): + current_stream_event = torch.cuda.Event() + current_stream_event.record(torch.cuda.current_stream()) + with torch.cuda.stream(self.statistic_stream): + current_stream_event.wait() + _update_statistic() + self.statistic_event.record(self.statistic_stream) + else: + _update_statistic() + self.local_statistic_tensor = None + def route(self, token_selected_experts: torch.Tensor, offset_by_ep_rank: bool = False) -> torch.Tensor:
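Editor's note, not part of the patch: a minimal CPU sketch of the hierarchical statistic flow the new ops implement, using made-up sizes for illustration. Each rank bincounts the expert IDs of its own tokens (the local_statistic / moe_hierarchical_statistic_local_device step), the per-rank counts are allgathered into a [ep_size, expert_count] tensor (what update_statistic receives), summed into a global count, and the decayed load factor is updated once per iteration (the update_statistic / moe_hierarchical_statistic_update step).

import torch

expert_count, top_k, decay_factor = 8, 2, 0.95
expert_load_factor = torch.zeros(expert_count)

# per-rank token_selected_experts, shape [num_local_tokens, top_k] (random stand-in data)
rank_expert_ids = [
    torch.randint(0, expert_count, (5, top_k)),
    torch.randint(0, expert_count, (7, top_k)),
]

# local statistic: one int32 count per expert on each rank
local_counts = [
    torch.bincount(ids.flatten(), minlength=expert_count).to(torch.int32)
    for ids in rank_expert_ids
]

# "allgather" across EP ranks -> [ep_size, expert_count], then reduce to the global count
gathered = torch.stack(local_counts)
global_count = gathered.sum(dim=0, dtype=torch.int32)

# decayed load factor update, mirroring updateLoadFactorKernel's semantics
expert_load_factor = expert_load_factor * decay_factor + global_count
print(expert_load_factor)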