diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index c106002e2c8..8fa46a11e87 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -161,17 +161,11 @@ class FusedMoeRunner : public torch::CustomClassHolder torch::optional input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, bool min_latency_mode, torch::optional> profile_ids) { - // Free the profile workspace to save memory - if (mProfileWorkspace != nullptr) - { - auto const cu_free_status = cudaFree(mProfileWorkspace); - TORCH_CHECK( - cu_free_status == cudaSuccess, "Can't free profile workspace for MoE GEMM profile before runMoe."); - mProfileWorkspace = nullptr; - } - std::lock_guard lock(mMutex); + // Free the profile workspace to save memory + freeProfileWorkspace(); + CHECK_INPUT(input, mActivationDtype) CHECK_INPUT(token_selected_experts, at::ScalarType::Int) if (token_final_scales) @@ -248,6 +242,9 @@ class FusedMoeRunner : public torch::CustomClassHolder { std::lock_guard lock(mMutex); + // Free the profile workspace to save memory + freeProfileWorkspace(); + CHECK_INPUT(input, mActivationDtype) CHECK_INPUT(token_selected_experts, at::ScalarType::Int) if (token_final_scales) @@ -375,13 +372,7 @@ class FusedMoeRunner : public torch::CustomClassHolder hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, parallelism_config); - if (mProfileWorkspace != nullptr) - { - auto const cu_free_status = cudaFree(mProfileWorkspace); - TORCH_CHECK(cu_free_status == cudaSuccess, - "Can't free profile workspace for MoE GEMM profile during memory reallocation."); - mProfileWorkspace = nullptr; - } + freeProfileWorkspace(); size_t profile_workspace_size = mProfiler->getWorkspaceSize(num_rows); auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size); TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile."); @@ -416,6 +407,17 @@ class FusedMoeRunner : public torch::CustomClassHolder using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; std::vector mAllProfiles; + void freeProfileWorkspace() + { + if (mProfileWorkspace != nullptr) + { + auto const cu_free_status = cudaFree(mProfileWorkspace); + TORCH_CHECK(cu_free_status == cudaSuccess, + "Can't free profile workspace for MoE GEMM profile during memory reallocation."); + mProfileWorkspace = nullptr; + } + } + void setRunnerProfiles(torch::optional> profile_ids) { if (mUseFp8BlockScaling) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 421d32bfbbc..2255799b100 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -120,13 +120,15 @@ def fused_moe( # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels tuning_config = TuningConfig(dynamic_tensors=( # input, dim 0, all valid buckets, map a seq_len to power of 2 bucket index - (0, 0, ((16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, - 2, 1), next_positive_power_of_2)), + (0, 0, ((8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1), + next_positive_power_of_2)), # min_latency_tensor, dim 0, (0 for False, 1 for True), map to it self (2, 0, ((0, ), lambda x: x)), )) - min_latency_tensor = torch.empty(1) if min_latency_mode else torch.empty(0) + # TODO: set min_latency_mode always to False due to the error in the moe_kernels + min_latency_tensor = torch.empty(0) + # allocate workspace for profiling moe_runner = MoERunner( x_dtype=input.dtype, diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 8984b3700ad..dca80671396 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -197,7 +197,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]: num_token_buckets.append(m) m //= 2 - return num_token_buckets + return tuple(num_token_buckets) def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]: