diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp
index c106002e2c8..8fa46a11e87 100644
--- a/cpp/tensorrt_llm/thop/moeOp.cpp
+++ b/cpp/tensorrt_llm/thop/moeOp.cpp
@@ -161,17 +161,11 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
         int64_t const ep_rank, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> profile_ids)
     {
-        // Free the profile workspace to save memory
-        if (mProfileWorkspace != nullptr)
-        {
-            auto const cu_free_status = cudaFree(mProfileWorkspace);
-            TORCH_CHECK(
-                cu_free_status == cudaSuccess, "Can't free profile workspace for MoE GEMM profile before runMoe.");
-            mProfileWorkspace = nullptr;
-        }
-
         std::lock_guard<std::mutex> lock(mMutex);
 
+        // Free the profile workspace to save memory
+        freeProfileWorkspace();
+
         CHECK_INPUT(input, mActivationDtype)
         CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
         if (token_final_scales)
@@ -248,6 +242,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
+        // Free the profile workspace to save memory
+        freeProfileWorkspace();
+
         CHECK_INPUT(input, mActivationDtype)
         CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
         if (token_final_scales)
@@ -375,13 +372,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
             hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
             min_latency_mode, parallelism_config);
 
-        if (mProfileWorkspace != nullptr)
-        {
-            auto const cu_free_status = cudaFree(mProfileWorkspace);
-            TORCH_CHECK(cu_free_status == cudaSuccess,
-                "Can't free profile workspace for MoE GEMM profile during memory reallocation.");
-            mProfileWorkspace = nullptr;
-        }
+        freeProfileWorkspace();
         size_t profile_workspace_size = mProfiler->getWorkspaceSize(num_rows);
         auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
         TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");
@@ -416,6 +407,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
     using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
     std::vector<Profile> mAllProfiles;
 
+    void freeProfileWorkspace()
+    {
+        if (mProfileWorkspace != nullptr)
+        {
+            auto const cu_free_status = cudaFree(mProfileWorkspace);
+            TORCH_CHECK(cu_free_status == cudaSuccess,
+                "Can't free profile workspace for MoE GEMM profile during memory reallocation.");
+            mProfileWorkspace = nullptr;
+        }
+    }
+
     void setRunnerProfiles(torch::optional<c10::ArrayRef<int64_t>> profile_ids)
     {
         if (mUseFp8BlockScaling)
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index dda4bcc3283..98b0a3c2a7d 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -6,6 +6,7 @@
 
 from ..autotuner import AutoTuner, TunableRunner, TuningConfig
 from ..utils import (get_last_power_of_2_num_tokens_buckets,
+                     get_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2, next_positive_power_of_2)
 
 
@@ -113,6 +114,7 @@ def fused_moe(
     ep_rank: int = 0,
     use_fp8_block_scaling: bool = False,
     min_latency_mode: bool = False,
+    tune_max_num_tokens: int = 8192,
 ) -> List[torch.Tensor]:
 
     tuner = AutoTuner.get()
@@ -120,8 +122,8 @@ def fused_moe(
     # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensors=(
         # input, dim 0, all valid buckets, map a seq_len to power of 2 bucket index
-        (0, 0, ((16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4,
-                 2, 1), next_positive_power_of_2)),
+        (0, 0, (get_power_of_2_num_tokens_buckets(tune_max_num_tokens),
+                next_positive_power_of_2)),
         # min_latency_tensor, dim 0, (0 for False, 1 for True), map to it self
         (2, 0, ((0, ), lambda x: x)),
     ))
@@ -194,6 +196,7 @@ def _(
     ep_rank: int = 0,
     use_fp8_block_scaling: bool = False,
     min_latency_mode: bool = False,
+    tune_max_num_tokens: int = 8192,
 ):
     seq_len = input.shape[0]
     hidden_size = fc2_expert_weights.shape[1]
diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index a1bed31c5e2..545f73f8e66 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -653,6 +653,7 @@ def forward_chunk(
             ep_rank=self.ep_rank,
             use_fp8_block_scaling=use_fp8_block_scaling,
             min_latency_mode=min_latency_mode,
+            tune_max_num_tokens=self.tune_max_num_tokens,
         )
 
         if min_latency_mode:
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 8984b3700ad..dca80671396 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -197,7 +197,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
         num_token_buckets.append(m)
         m //= 2
 
-    return num_token_buckets
+    return tuple(num_token_buckets)
 
 
 def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
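
A minimal sketch of the bucket generation this patch switches to, for sanity-checking the new `tune_max_num_tokens` plumbing. Only the `append(m)` / `m //= 2` loop body and the `tuple(...)` return are visible in the `utils.py` hunk above; the loop setup and the exact semantics of `next_positive_power_of_2` are assumptions for illustration, not the library's confirmed implementation:

```python
from typing import Tuple


def next_positive_power_of_2(x: int) -> int:
    # Assumed semantics of the helper imported in torch_custom_ops.py:
    # the smallest power of two >= x, with a floor of 1.
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def get_power_of_2_num_tokens_buckets(max_num_tokens: int) -> Tuple[int, ...]:
    # Walk down from the power-of-2 ceiling of max_num_tokens to 1,
    # mirroring the append(m) / m //= 2 loop shown in the utils.py hunk.
    m = next_positive_power_of_2(max_num_tokens)
    num_token_buckets = []
    while m >= 1:
        num_token_buckets.append(m)
        m //= 2
    # Returned as a tuple so it drops into TuningConfig exactly where the
    # old hard-coded tuple literal sat.
    return tuple(num_token_buckets)


# With the new default tune_max_num_tokens=8192, the buckets run from 8192
# down to 1; the old literal additionally pinned a 16384 bucket.
assert get_power_of_2_num_tokens_buckets(8192) == (
    8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1)
```

The practical effect: the profiling buckets are derived from `tune_max_num_tokens` (threaded from `fused_moe.py` through the custom op and its shape-inference stub `def _`) rather than pinned to a hard-coded tuple, so deployments expecting larger token counts can widen the tuned range without patching the op.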