diff --git a/cpp/include/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h index 3a5153259fe..6626b18e388 100644 --- a/cpp/include/tensorrt_llm/common/cudaUtils.h +++ b/cpp/include/tensorrt_llm/common/cudaUtils.h @@ -38,6 +38,7 @@ #include #include #include +#include #ifndef _WIN32 // Linux #include #endif // not WIN32 @@ -432,6 +433,21 @@ inline int getMaxSharedMemoryPerBlockOptin() return nByteMaxSharedMemoryPerBlockOptin; } +template +inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize) +{ + static std::unordered_map cache; + auto it = cache.find(kernel); + if (it != cache.end()) + { + return it->second; + } + int numBlocks; + check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize)); + cache[kernel] = numBlocks; + return numBlocks; +} + template inline size_t divUp(T1 const& a, T2 const& b) { diff --git a/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu index 90664105008..3fa5fae3afa 100644 --- a/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu +++ b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu @@ -67,7 +67,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out int64_t const kCopyPerToken = hidden_size / kElemPerCopy; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); + cudaGridDependencySynchronize(); #endif int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size; @@ -110,7 +110,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + cudaTriggerProgrammaticLaunchCompletion(); #endif } @@ -141,12 +141,12 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const } #endif + auto kernel = &moePermuteKernel; static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); - int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens); + int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0); + int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens); int32_t const threads = kThreadsPerBlock; - auto kernel = &moePermuteKernel; - cudaLaunchConfig_t config; config.gridDim = blocks; config.blockDim = threads; @@ -195,7 +195,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o int32_t const token_idx = blockIdx.x; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); + cudaGridDependencySynchronize(); #endif auto* dst_ptr = reinterpret_cast(output) + token_idx * kCopyPerToken; @@ -232,7 +232,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + cudaTriggerProgrammaticLaunchCompletion(); #endif } @@ -277,6 +277,105 @@ INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16); #endif #undef INSTANTIATE_MOE_UNPERMUTE +template +__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit, + int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, + int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size) +{ + int32_t constexpr kElemPerCopy = 
elemPerCopy(); + int64_t const kCopyPerToken = hidden_size / kElemPerCopy; + + InputType rmem[kElemPerCopy]; +#pragma unroll + for (int32_t j = 0; j < kElemPerCopy; j++) + { + rmem[j] = 0; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size; + for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x) + { + int32_t const tile_idx = permuted_idx / tile_size; + if (permuted_idx >= tile_idx_to_mn_limit[tile_idx]) + { + continue; + } + int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx]; + int32_t const token_idx = expanded_idx / top_k; + int32_t const topk_idx = expanded_idx % top_k; + + bool is_first_in_topk = true; + for (int32_t k = 0; k < topk_idx; k++) + { + if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0) + { + is_first_in_topk = false; + break; + } + } + if (!is_first_in_topk) + { + continue; + } + + auto* dst_ptr = reinterpret_cast(input) + token_idx * kCopyPerToken; + for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock) + { + dst_ptr[i] = *reinterpret_cast(rmem); + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx, + int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles, + int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size, + cudaStream_t stream) +{ + int32_t constexpr kThreadsPerBlock = 256; + int32_t constexpr kElemPerCopy = elemPerCopy(); + TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy); + + auto kernel = &moeOutputMemsetKernel; + static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0); + int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens); + int32_t const threads = kThreadsPerBlock; + + cudaLaunchConfig_t config; + config.gridDim = blocks; + config.blockDim = threads; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size); +} + +#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType) \ + template void moeOutputMemset(InputType * input, int32_t const* tile_idx_to_mn_limit, \ + int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, \ + int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size, \ + int32_t const top_k, int32_t const tile_size, cudaStream_t stream) + +INSTANTIATE_MOE_OUTPUT_MEMSET(half); +#ifdef ENABLE_BF16 +INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16); +#endif +#undef INSTANTIATE_MOE_OUTPUT_MEMSET + template __global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf, @@ -297,7 +396,7 @@ __global__ void 
moeActivationKernel(InputType const* input, OutputType* output, ActFn act{}; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); + cudaGridDependencySynchronize(); #endif float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0]; @@ -353,7 +452,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output, } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + cudaTriggerProgrammaticLaunchCompletion(); #endif } @@ -382,10 +481,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob } #endif - static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); - int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens); - int32_t const threads = kThreadsPerBlock; - auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit, @@ -424,6 +519,11 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob }; auto kernel = get_act_kernel(activation_params.activation_type); + static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0); + int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens); + int32_t const threads = kThreadsPerBlock; + cudaLaunchConfig_t config; config.gridDim = blocks; config.blockDim = threads; diff --git a/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h index 0659b4c78f6..2bd356e3b04 100644 --- a/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h +++ b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h @@ -32,6 +32,12 @@ void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t co TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k, cudaStream_t stream); +template +void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx, + int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles, + int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size, + cudaStream_t stream); + template void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index f1e74eda4aa..76c7c585866 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int64_t num_padding_tokens = 0; #endif - static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); - // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
- int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens)); - int64_t const threads = EXPAND_THREADS_PER_BLOCK; - auto func = [&]() { #ifdef ENABLE_FP8 @@ -1637,6 +1632,12 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, } }(); + static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0); + int32_t const blocks + = std::min(smCount * maxBlocksPerSM, static_cast(std::max(num_rows * k, num_padding_tokens))); + int32_t const threads = EXPAND_THREADS_PER_BLOCK; + cudaLaunchConfig_t config; config.gridDim = blocks; config.blockDim = threads; @@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro if (parallelism_config.ep_size > 1 && enable_alltoall) { // If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros. - static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); - // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). - int64_t const blocks = smCount * 8; - int64_t const threads = FINALIZE_THREADS_PER_BLOCK; - config.gridDim = blocks; - config.blockDim = threads; auto func = final_scales ? &finalizeMoeRoutingNoFillingKernel : &finalizeMoeRoutingNoFillingKernel; + + static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int32_t const maxBlocksPerSM + = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0); + int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast(num_rows * experts_per_token)); + int32_t const threads = FINALIZE_THREADS_PER_BLOCK; + + config.gridDim = blocks; + config.blockDim = threads; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node, @@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 int64_t num_padding_tokens = 0; #endif - static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); - // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
- int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens)); - int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; - auto fn = [&]() { // IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in @@ -2302,6 +2301,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 } }(); + static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0); + int32_t const blocks + = std::min(smCount * maxBlocksPerSM, static_cast(std::max(expanded_num_tokens, num_padding_tokens))); + int32_t const threads = ACTIVATION_THREADS_PER_BLOCK; + cudaLaunchConfig_t config; config.gridDim = blocks; config.blockDim = threads; diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu index 59e16ccafd8..462fd5a091e 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu @@ -647,7 +647,9 @@ void run(Data& data, void* stream) // // The upper bound is a strict requirement. The number of blocks should be determined by querying // the device properties, or conservatively low. - static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount(); + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + // WAR: Reserve 8 SMs for overlapping kernels. + int const numBlocksCoop = smCount - 8; // Maximum number of tokens supported by the kernel using a cooperative launch. int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; diff --git a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp index 23aee486baf..54c45031a19 100644 --- a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp +++ b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp @@ -139,6 +139,8 @@ std::tuple> moe_permute(torch::Ten TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32."); int64_t const num_tiles = tile_idx_to_mn_limit.size(0); TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D."); + TORCH_CHECK( + permuted_idx_to_expanded_idx.scalar_type() == torch::kInt32, "permuted_idx_to_expanded_idx must be int32."); int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0); TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles, "max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles."); @@ -253,6 +255,69 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c return output; } +void moe_output_memset_inplace(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit, + torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& permuted_idx_to_expanded_idx, + torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim, int64_t const top_k, + int64_t const ep_size, bool const enable_alltoall = false) +{ + TORCH_CHECK(input.dim() == 2, "input must be 2D."); + int64_t const num_tokens = input.size(0); + int64_t const hidden_size = input.size(1); + TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D."); + TORCH_CHECK( + expanded_idx_to_permuted_idx.scalar_type() == torch::kInt32, 
"expanded_idx_to_permuted_idx must be int32."); + TORCH_CHECK( + expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens."); + TORCH_CHECK(expanded_idx_to_permuted_idx.size(1) == top_k, "expanded_idx_to_permuted_idx.size(1) must be top_k."); + TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D."); + TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32."); + int64_t const num_tiles = tile_idx_to_mn_limit.size(0); + TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D."); + TORCH_CHECK( + permuted_idx_to_expanded_idx.scalar_type() == torch::kInt32, "permuted_idx_to_expanded_idx must be int32."); + int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0); + TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles, + "max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles."); + TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k, + "max_num_permuted_tokens must be greater than or equal to num_tokens * top_k."); + + TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element."); + TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32."); + + auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device()); + +#define DISPATCH_MOE_OUTPUT_MEMSET(InputType) \ + do \ + { \ + if (!enable_alltoall || ep_size <= top_k) \ + { \ + cudaMemsetAsync(input.data_ptr(), 0x0, sizeof(InputType) * num_tokens * hidden_size, stream); \ + } \ + else \ + { \ + tensorrt_llm::kernels::cute_dsl::moeOutputMemset(static_cast(input.data_ptr()), \ + tile_idx_to_mn_limit.data_ptr(), expanded_idx_to_permuted_idx.data_ptr(), \ + permuted_idx_to_expanded_idx.data_ptr(), num_non_exiting_tiles.data_ptr(), \ + max_num_permuted_tokens, hidden_size, top_k, tile_tokens_dim, stream); \ + } \ + } while (0) + + if (input.scalar_type() == torch::kHalf) + { + DISPATCH_MOE_OUTPUT_MEMSET(half); + } + else if (input.scalar_type() == torch::kBFloat16) + { + DISPATCH_MOE_OUTPUT_MEMSET(__nv_bfloat16); + } + else + { + TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type()); + } + +#undef DISPATCH_MOE_OUTPUT_MEMSET +} + // Activation torch::Tensor moe_swiglu(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit, @@ -421,6 +486,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) "moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, " "Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)"); m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor"); + m.def( + "moe_output_memset_inplace(Tensor(a!) 
input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, " + "Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k, int " + "ep_size, bool enable_alltoall = False) -> ()"); m.def( "moe_swiglu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, " "int tile_tokens_dim) -> Tensor"); @@ -438,6 +507,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("moe_sort", &torch_ext::moe_sort); m.impl("moe_permute", &torch_ext::moe_permute); m.impl("moe_unpermute", &torch_ext::moe_unpermute); + m.impl("moe_output_memset_inplace", &torch_ext::moe_output_memset_inplace); m.impl("moe_swiglu", &torch_ext::moe_swiglu); m.impl("moe_swiglu_nvfp4_quantize", &torch_ext::moe_swiglu_nvfp4_quantize); m.impl("moe_gelu", &torch_ext::moe_gelu); diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 691ecb45911..feecf3d174b 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -730,7 +730,7 @@ def choose_one( # Log the cache miss. Expect no cache miss in inference. if not is_cache_hit: logger.warning_once( - f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}", + f"[AutoTuner] {custom_op} using the fallback tactic, due to cache miss on input shapes={input_shapes}", key=(custom_op, "warning_autotuning_cache_miss_fallback")) return (best_runner, best_tactic) diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py index 39ab46de402..2dc6914bc27 100644 --- a/tensorrt_llm/_torch/compilation/utils.py +++ b/tensorrt_llm/_torch/compilation/utils.py @@ -77,6 +77,13 @@ def inplace_info(): torch.ops.trtllm.logits_bitmask.default: { 1: "logits" }, + torch.ops.trtllm.moe_output_memset_inplace.default: { + 1: "input" + }, + torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default: + { + 6: "output" + }, torch.ops.trtllm.pp_recv_tensors.default: { 1: "tensors" }, diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 842c48725f8..703dcc430a5 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -149,7 +149,7 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: def inputs_pre_hook_finalize_fusion( self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs + a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs num_tokens = self.infer_num_tokens(a.size(0)) num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens) tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx( @@ -184,7 +184,7 @@ def inputs_pre_hook_finalize_fusion( [num_non_exiting_tiles_val], dtype=num_non_exiting_tiles.dtype, device=num_non_exiting_tiles.device) - return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales + return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales class FusedMoEInputsHelper: @@ -268,8 +268,7 @@ def get_valid_tactics( **kwargs, ) -> List[Tuple[int, int]]: # Early 
exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103 - sm_version = get_sm_version() - if sm_version not in [100, 103]: + if (sm_version := get_sm_version()) not in (100, 103): logger.debug( f"CuteDSL: SM version {sm_version} is not supported. " f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics." @@ -597,8 +596,7 @@ def cute_dsl_nvfp4_gemm_blackwell( for automatic backend selection with better performance. """ # Validate SM version before attempting to use CuteDSL - sm_version = get_sm_version() - if sm_version not in [100, 103]: + if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. " f"Please use nvfp4_gemm with backend='auto' for automatic backend selection." @@ -660,9 +658,9 @@ def __init__(self, self.output_dtype = output_dtype self.scaling_vector_size = scaling_vector_size - if get_sm_version() != 100: + if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( - f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100" + f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}" ) def unique_id(self): @@ -947,9 +945,9 @@ def __init__(self, self.output_dtype = output_dtype self.scaling_vector_size = scaling_vector_size - if get_sm_version() != 100: + if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( - f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100" + f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}" ) def unique_id(self): @@ -1015,11 +1013,12 @@ def get_tuning_config(self) -> TuningConfig: helper.map_to_tuning_buckets), ), constraint_specs=( ConstraintSpec(2, 0, fp4_scale_infer_shape), - ConstraintSpec(5, 0, helper.infer_shape_max_num_tiles), + ConstraintSpec(5, 0, helper.infer_shape_num_tokens), ConstraintSpec(6, 0, helper.infer_shape_max_num_tiles), + ConstraintSpec(7, 0, helper.infer_shape_max_num_tiles), ConstraintSpec( - 7, 0, helper.infer_shape_max_num_permuted_tokens), - ConstraintSpec(9, 0, helper.infer_shape_num_tokens)), + 8, 0, helper.infer_shape_max_num_permuted_tokens), + ConstraintSpec(10, 0, helper.infer_shape_num_tokens)), inputs_pre_hook=helper.inputs_pre_hook_finalize_fusion, use_cuda_graph=True, ) @@ -1027,7 +1026,7 @@ def get_tuning_config(self) -> TuningConfig: def forward(self, inputs: List[torch.Tensor], tactic: Optional[tuple]) -> torch.Tensor: - a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs + a, b, a_sf, b_sf, alpha, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs assert a.dtype == torch.float4_e2m1fn_x2 assert a.dim() == 2 assert b.dtype == torch.float4_e2m1fn_x2 @@ -1051,6 +1050,11 @@ def forward(self, inputs: List[torch.Tensor], assert b_sf.size(2) == scale_k assert alpha.size(0) == l + assert c.dtype == self.output_dtype + assert c.dim() == 2 + num_tokens = c.size(0) + assert c.size(1) == n + num_tiles = m // self.tile_size assert tile_idx_to_group_idx.dtype == torch.int32 assert tile_idx_to_group_idx.size() == (num_tiles, ) @@ -1062,14 +1066,7 @@ def forward(self, inputs: List[torch.Tensor], assert num_non_exiting_tiles.numel() == 1 assert 
token_final_scales.dtype == torch.float32 assert token_final_scales.dim() == 2 - num_tokens = token_final_scales.size(0) - assert token_final_scales.size(1) == self.top_k - - # TODO: Overlap the memset - c = torch.zeros(num_tokens, - n, - dtype=self.output_dtype, - device=a.device) + assert token_final_scales.size() == (num_tokens, self.top_k) a_ptr = make_ptr(cutlass.Float4E2M1FN, a.data_ptr(), @@ -1183,15 +1180,16 @@ def forward(self, inputs: List[torch.Tensor], return c @torch.library.custom_op( - "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell", - mutates_args=(), + "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell", + mutates_args=("output", ), device_types="cuda") - def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( + def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input: torch.Tensor, weight: torch.Tensor, input_scale: torch.Tensor, weight_scale: torch.Tensor, alpha: torch.Tensor, + output: torch.Tensor, tile_idx_to_group_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, permuted_idx_to_expanded_idx: torch.Tensor, @@ -1204,26 +1202,77 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( tile_size: int, output_dtype: torch.dtype, scaling_vector_size: int = 16, - ) -> torch.Tensor: + ) -> None: tuner = AutoTuner.get() runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner( num_experts, top_k, num_local_experts, local_expert_offset, tile_size, output_dtype, scaling_vector_size) + inputs = [ - input, weight, input_scale, weight_scale, alpha, + input, weight, input_scale, weight_scale, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales ] _, best_tactic = tuner.choose_one( - "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell", + "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell", [runner], runner.get_tuning_config(), inputs, ) - output = runner(inputs, tactic=best_tactic) + runner(inputs, tactic=best_tactic) + + @torch.library.custom_op( + "trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell", + mutates_args=(), + device_types="cuda") + def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + token_final_scales: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + output_dtype: torch.dtype, + scaling_vector_size: int = 16, + ) -> torch.Tensor: + num_tokens = token_final_scales.size(0) + n = weight.size(1) + output = torch.zeros(num_tokens, + n, + dtype=output_dtype, + device=input.device) + torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( + input=input, + weight=weight, + input_scale=input_scale, + weight_scale=weight_scale, + alpha=alpha, + output=output, + tile_idx_to_group_idx=tile_idx_to_group_idx, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + token_final_scales=token_final_scales, + num_experts=num_experts, + top_k=top_k, + num_local_experts=num_local_experts, + local_expert_offset=local_expert_offset, + tile_size=tile_size, + output_dtype=output_dtype, + scaling_vector_size=scaling_vector_size, + ) return output @torch.library.register_fake( @@ 
-1275,9 +1324,9 @@ def __init__(self, self.tile_size = tile_size self.scaling_vector_size = scaling_vector_size - if get_sm_version() != 100: + if (sm_version := get_sm_version()) not in (100, 103): raise ValueError( - f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100" + f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}" ) def unique_id(self): diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py index fe20d521e8c..f16c62a4173 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py @@ -2631,6 +2631,7 @@ def wrapper( ): scale_k = k // scaling_vector_size interm_size = n // 2 + scale_interm_size = interm_size // scaling_vector_size num_tiles = m // tile_size a = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout((m, k, 1), order=(1, 0, 2))) b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2))) @@ -2652,7 +2653,7 @@ def wrapper( c_sf = cute.make_tensor( c_sf_ptr, layout=cute.make_ordered_layout( - (32, 4, interm_size // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) + (32, 4, m // 128, 4, scale_interm_size // 4, 1), order=(2, 1, 4, 0, 3, 5) ), ) alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index dfe1f091857..717d8f78fe2 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -47,7 +47,9 @@ NVLinkOneSided, NVLinkTwoSided, ) +from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE +from .fused_moe_deepgemm import DeepGemmFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE @@ -56,7 +58,7 @@ class ConfigurableMoE(MoE): Configurable MoE layer using composition pattern with automatic configuration This class orchestrates the MoE execution flow by composing: - - moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, WideEPMoE, etc.) + - moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, CuteDslFusedMoE, etc.) Note: Current FusedMoE implementations are used as backends (transitional). Future will have dedicated MoEBackend interface. 
- Communication: Handles distributed communication (auto-selected) @@ -797,7 +799,7 @@ def backend(self) -> MoE: """ Get the current MoE backend implementation - Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, WideEPMoE) + Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, CuteDslFusedMoE) """ return self._backend @@ -902,27 +904,26 @@ def _get_backend_kwargs( Returns: Dict: Backend-specific keyword arguments """ - backend_name = self.backend.__class__.__name__ kwargs = {} # Common parameters for Cutlass and DeepGemm - if backend_name in ["CutlassFusedMoE", "DeepGemmFusedMoE"]: + if self.backend.__class__ in (CutlassFusedMoE, DeepGemmFusedMoE, CuteDslFusedMoE): pass # Cutlass-specific parameters - if backend_name == "CutlassFusedMoE": + if self.backend.__class__ == CutlassFusedMoE: pass - # WideEP-specific parameters - elif backend_name == "WideEPMoE": - pass + # CuteDSL-specific parameters + elif self.backend.__class__ == CuteDslFusedMoE: + kwargs["enable_alltoall"] = self.enable_alltoall # DeepGemm-specific parameters - elif backend_name == "DeepGemmFusedMoE": + elif self.backend.__class__ == DeepGemmFusedMoE: pass # TRTLLMGen-specific parameters - elif backend_name == "TRTLLMGenFusedMoE": + elif self.backend.__class__ == TRTLLMGenFusedMoE: # Determine router_logits based on whether routing has been done # If backend doesn't support load balancer, routing is done before communication # In that case, router_logits should be None (routing already done) diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index d5b22eac420..368ad0c07b8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -20,6 +20,8 @@ from .moe_load_balancer import get_moe_load_balancer from .routing import BaseMoeRoutingMethod +ENABLE_CONFIGURABLE_MOE = os.environ.get("ENABLE_CONFIGURABLE_MOE", "0") == "1" + def get_moe_cls( model_config: ModelConfig, @@ -33,7 +35,16 @@ def get_moe_cls( elif moe_backend.upper() == "VANILLA": return VanillaMoE elif moe_backend.upper() == "CUTEDSL": - return CuteDslFusedMoE + if quant_config is not None and ( + quant_config.quant_mode.has_fp8_block_scales() + or quant_config.quant_mode.has_nvfp4()): + return CuteDslFusedMoE + else: + logger.warning( + "CuteDslFusedMoE only supports fp8_block_scales and nvfp4. " + f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead." + ) + return CutlassFusedMoE elif moe_backend.upper() == "DEEPGEMM": return DeepGemmFusedMoE elif moe_backend.upper() == "TRTLLM": @@ -48,8 +59,8 @@ def get_moe_cls( else: logger.warning( "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. " - f"Check out details in quant_config: {quant_config}" - "Using CutlassFusedMoE instead.") + f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead." + ) return CutlassFusedMoE elif moe_backend.upper() == "WIDEEP": return WideEPMoE @@ -129,8 +140,8 @@ def create_moe_backend( moe_load_balancer = get_moe_load_balancer() if moe_load_balancer is not None: assert moe_cls in [ - WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE - ], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE and TRTLLMGenFusedMoE now." + WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE + ], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE and CuteDslFusedMoE now." 
if bias: assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE @@ -229,6 +240,8 @@ def create_moe_backend( weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, + init_load_balancer=init_load_balancer, + without_comm=without_comm, ) elif moe_cls == DeepGemmFusedMoE: return moe_cls( @@ -331,13 +344,9 @@ def create_moe( moe_cls = get_moe_cls(model_config, override_quant_config) - # Check if ENABLE_CONFIGURABLE_MOE environment variable is set - enable_configurable_moe = os.environ.get('ENABLE_CONFIGURABLE_MOE', - '0') == '1' - - if enable_configurable_moe: - # ConfigurableMoE is only supported for TRTLLMGenFusedMoE backend - if moe_cls == TRTLLMGenFusedMoE: + if ENABLE_CONFIGURABLE_MOE or moe_cls == CuteDslFusedMoE: + # ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends + if moe_cls in (TRTLLMGenFusedMoE, CuteDslFusedMoE): return ConfigurableMoE( routing_method=routing_method, num_experts=num_experts, @@ -358,12 +367,13 @@ def create_moe( else: # Check if this is a TRTLLM backend request that fallback to CutlassFusedMoE requested_backend = model_config.moe_backend.upper() - if requested_backend == "TRTLLM" and moe_cls == CutlassFusedMoE: + if requested_backend in ("TRTLLM", + "CUTEDSL") and moe_cls == CutlassFusedMoE: # Workaround for test cases where TRTLLM backend fallbacks to CutlassFusedMoE due to quant_config incompatibility # Log warning and continue with the fallback backend logger.warning( f"ENABLE_CONFIGURABLE_MOE is set but TRTLLM backend fallback to {moe_cls.__name__} due to quant_config. " - f"ConfigurableMoE only supports TRTLLMGenFusedMoE backend. " + f"ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends. " f"Continuing with legacy MoE backend {moe_cls.__name__}.") else: # For other incompatible backends, raise error diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 5eb32965d81..a087a4c87a8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -8,7 +8,7 @@ from ...distributed import allgather from ...model_config import ModelConfig -from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div +from ...utils import AuxStreamType, Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE from .interface import AlltoallMethodType from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod @@ -180,6 +180,8 @@ def __init__( VANILLA, apply_router_weight_on_input: bool = False, layer_idx: Optional[int] = None, + init_load_balancer: bool = True, + without_comm: bool = False, ): super().__init__( @@ -194,6 +196,8 @@ def __init__( weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, + init_load_balancer=init_load_balancer, + without_comm=without_comm, ) def select_alltoall_method_type(self) -> AlltoallMethodType: @@ -206,175 +210,60 @@ def _get_quant_method(self): return NVFP4CuteDslFusedMoEMethod() return super()._get_quant_method() - def forward_chunk_unquantized( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - repeating_info: tuple = (True, True), - ) -> torch.Tensor: - assert not self.has_any_quant - return 
super().forward_chunk(x, - router_logits, - output_dtype=output_dtype, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - repeating_info=repeating_info) - - def forward_chunk_fp8_block_scales( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - repeating_info: tuple = (True, True), - ) -> torch.Tensor: - assert self.has_deepseek_fp8_block_scales - - # apply routing - token_selected_experts, token_final_scales = self.routing_method.apply( - router_logits) - assert token_selected_experts.shape[ - 1] == self.routing_method.experts_per_token - assert token_selected_experts.shape == token_final_scales.shape - assert token_selected_experts.shape[0] == router_logits.shape[0] - assert token_final_scales.dtype == torch.float32 - assert token_selected_experts.dtype == torch.int32 - - if self.apply_router_weight_on_input: - assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing" - assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" - x = x * token_final_scales.to(x.dtype) - # TODO: remove this once we have correct fusedmoe kernel ready - token_final_scales = None - - weight_dtype = self.w3_w1_weight.dtype - - ( - permuted_row_to_unpermuted_row_tensor, - permuted_token_selected_experts_tensor, - permuted_data_tensor, - expert_first_token_offset_tensor, - permuted_token_final_scales_tensor, - unpermuted_row_to_permuted_row_tensor, - ) = torch.ops.trtllm.moe_permute_op( - x, - token_selected_experts, - token_final_scales, - None, # w3_w1_weight.view(weight_dtype), - None, # w2_weight.view(weight_dtype), - None, # quant_scales, - input_sf=None, - num_experts_on_rank=self.expert_size_per_partition, - tp_size=self.tp_size, - tp_rank=self.tp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - cluster_size=self.cluster_size, - cluster_rank=self.cluster_rank, - min_latency_mode=False, - use_fp8_block_scaling=True, - ) - act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( - permuted_data_tensor) - h1 = cute_dsl_fp8_group_blockwise_gemm_ref( - a=act_input_fp8, - b=self.w3_w1_weight.view(weight_dtype), - a_sf=act_input_sf, - b_sf=self.quant_scales[0], - offset_array=expert_first_token_offset_tensor, - ) - h2 = swiglu_fused_moe(h1) - act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2) - h3 = cute_dsl_fp8_group_blockwise_gemm_ref( - a=act_input_fp8, - b=self.w2_weight.view(weight_dtype), - a_sf=act_input_sf, - b_sf=self.quant_scales[1], - offset_array=expert_first_token_offset_tensor, - ) - h4 = torch.ops.trtllm.moe_finalize_scale_op( - h3, - None, # biases - token_final_scales, - unpermuted_row_to_permuted_row_tensor, - permuted_row_to_unpermuted_row_tensor, - token_selected_experts, - expert_first_token_offset_tensor, - False, # enable_alltoall - x.shape[0], # num_rows - x.shape[1], # (possibly padded) hidden_size - self.unpadded_hidden_size, # original hidden size - self.routing_method.top_k, - self.expert_size_per_partition, # num_experts_per_node - self.tp_size, - self.tp_rank, - self.ep_size, - self.ep_rank, - ) - return h4 - - def forward_chunk_nvfp4( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - 
repeating_info: tuple = (True, True), - ) -> torch.Tensor: - assert self.has_nvfp4 - - if isinstance(x, Fp4QuantizedTensor): - assert output_dtype is not None - else: - output_dtype = x.dtype - - # apply routing - token_selected_experts, token_final_scales = self.routing_method.apply( - router_logits) - assert token_selected_experts.shape[ - 1] == self.routing_method.experts_per_token - assert token_selected_experts.shape == token_final_scales.shape - assert token_selected_experts.shape[0] == router_logits.shape[0] - assert token_final_scales.dtype == torch.float32 - assert token_selected_experts.dtype == torch.int32 - - run_post_quant_allgather = self.use_dp and self.parallel_size > 1 - if run_post_quant_allgather: + def quantize_input(self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + post_quant_comm: bool = True): + """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. + + Args: + x: Input tensor to quantize + post_quant_comm: + If True, quantize for post-quant communication path. + If False, quantize for non-communication path + + Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed + + For quantization methods that produce scaling factors: + - x_sf is reshaped from 1D to 2D: [num_elements] -> [batch_size, ceil_div(hidden_size, scaling_vector_size)] + - The 2D shape is required for proper handling in alltoall/allgather operations + - scaling_vector_size is typically the group size for block-wise quantization + """ + x_sf = None + if self.has_nvfp4: if isinstance(x, Fp4QuantizedTensor): assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication" + x_row = x.shape[0] x, x_sf = x.fp4_tensor, x.scaling_factor else: + x_row = x.shape[0] x, x_sf = torch.ops.trtllm.fp4_quantize( x, self.fc31_input_scale, self.scaling_vector_size, False, False) - # note: we use uint8 to store 2 fp4 values - x_row, x_col = x.size(0), x.size(1) * 2 + elif self.has_deepseek_fp8_block_scales: + # FP8 block scales doesn't support permutation of quantized inputs. + # WAR: The quantization is in run_moe_fp8_block_scales. + pass else: - if not isinstance(x, Fp4QuantizedTensor): - x, x_sf = torch.ops.trtllm.fp4_quantize( - x, self.fc31_input_scale, self.scaling_vector_size, False, - False) + raise ValueError( + f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}." 
+ ) - if run_post_quant_allgather: - # Original allgather logic - if x_sf is not None: - x_sf = x_sf.view(x_row, ceil_div(x_col, - self.scaling_vector_size)) - assert x_sf.dim( - ) == 2, "The hidden states scaling factor should be 2D tensor before allgather" - - x, x_sf, token_selected_experts, token_final_scales = allgather( - [x, x_sf, token_selected_experts, token_final_scales], - self.mapping, - dim=0, - sizes=None if use_dp_padding else all_rank_num_tokens) + if x_sf is not None: + x_sf = x_sf.view(x_row, -1) + return x, x_sf + def run_moe_nvfp4( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + ) -> torch.Tensor: + assert self.has_nvfp4 + output_dtype = torch.bfloat16 tile_size = 128 + tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort( token_selected_experts=token_selected_experts, token_final_scales=token_final_scales, @@ -409,13 +298,28 @@ def forward_chunk_nvfp4( tile_size=tile_size, ) if self.use_fused_finalize: - x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell( + output = torch.empty((token_final_scales.size(0), self.hidden_size), + dtype=output_dtype, + device=x.device) + torch.ops.trtllm.moe_output_memset_inplace( + input=output, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + tile_tokens_dim=tile_size, + top_k=self.routing_method.experts_per_token, + ep_size=self.mapping.moe_ep_size, + enable_alltoall=enable_alltoall, + ) + torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell( input=x.view(torch.float4_e2m1fn_x2), weight=self.w2_weight.view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), weight_scale=self.quant_scales.fc2_weight_block.view( torch.uint8), alpha=self.quant_scales.fc2_global, + output=output, tile_idx_to_group_idx=tile_idx_to_expert_idx, tile_idx_to_mn_limit=tile_idx_to_mn_limit, permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, @@ -428,6 +332,7 @@ def forward_chunk_nvfp4( tile_size=tile_size, output_dtype=output_dtype, ) + x = output else: x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell( input=x.view(torch.float4_e2m1fn_x2), @@ -452,6 +357,127 @@ def forward_chunk_nvfp4( ) return x + def run_moe_fp8_block_scales( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + ) -> torch.Tensor: + assert self.has_deepseek_fp8_block_scales + assert x_sf is None + weight_dtype = self.w3_w1_weight.dtype + + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_experts, + token_final_scales, + None, # w3_w1_weight.view(weight_dtype), + None, # w2_weight.view(weight_dtype), + None, # quant_scales, + input_sf=None, + num_experts_on_rank=self.expert_size_per_partition, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + cluster_size=self.cluster_size, + 
cluster_rank=self.cluster_rank, + min_latency_mode=False, + use_fp8_block_scaling=True, + ) + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( + permuted_data_tensor) + h1 = cute_dsl_fp8_group_blockwise_gemm_ref( + a=act_input_fp8, + b=self.w3_w1_weight.view(weight_dtype), + a_sf=act_input_sf, + b_sf=self.quant_scales[0], + offset_array=expert_first_token_offset_tensor, + ) + h2 = swiglu_fused_moe(h1) + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2) + h3 = cute_dsl_fp8_group_blockwise_gemm_ref( + a=act_input_fp8, + b=self.w2_weight.view(weight_dtype), + a_sf=act_input_sf, + b_sf=self.quant_scales[1], + offset_array=expert_first_token_offset_tensor, + ) + h4 = torch.ops.trtllm.moe_finalize_scale_op( + h3, + None, # biases + token_final_scales, + unpermuted_row_to_permuted_row_tensor, + permuted_row_to_unpermuted_row_tensor, + token_selected_experts, + expert_first_token_offset_tensor, + enable_alltoall, + x.shape[0], # num_rows + x.shape[1], # (possibly padded) hidden_size + self.unpadded_hidden_size, # original hidden size + self.routing_method.top_k, + self.expert_size_per_partition, # num_experts_per_node + self.tp_size, + self.tp_rank, + self.ep_size, + self.ep_rank, + ) + return h4 + + def run_moe( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + ) -> torch.Tensor: + """ + Run MoE computation with CuteDSL backend. + + This method encapsulates the core MoE computation logic, handling different + quantization schemes (fp8_block_scales and nvfp4). + + Args: + # Standard MoE interface parameters: + x: Input hidden states (may be pre-quantized) + token_selected_experts: Expert IDs [num_tokens, top_k]. If EPLB is enabled, + this represents expert slots [num_tokens, top_k] instead. + token_final_scales: Final scaling factors for each token + x_sf: Input scale factors (optional, for certain quantization schemes) + enable_alltoall: Whether alltoall communication is enabled. + + Returns: + final_hidden_states tensor. + """ + if self.has_nvfp4: + return self.run_moe_nvfp4( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=enable_alltoall) + elif self.has_deepseek_fp8_block_scales: + return self.run_moe_fp8_block_scales( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=enable_alltoall) + else: + raise ValueError( + f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}." 
+ ) + def forward_chunk( self, x: Union[torch.Tensor, Fp4QuantizedTensor], @@ -461,32 +487,30 @@ def forward_chunk( use_dp_padding: Optional[bool] = None, repeating_info: tuple = (True, True), ) -> torch.Tensor: - if self.has_any_quant: - if self.has_nvfp4: - return self.forward_chunk_nvfp4( - x, - router_logits, - output_dtype=output_dtype, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - repeating_info=repeating_info) - elif self.has_deepseek_fp8_block_scales: - return self.forward_chunk_fp8_block_scales( - x, - router_logits, - output_dtype=output_dtype, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - repeating_info=repeating_info) - else: - raise ValueError( - f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}" - ) - else: - return self.forward_chunk_unquantized( - x, - router_logits, - output_dtype=output_dtype, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - repeating_info=repeating_info) + # Currently, the default path is that ConfigurableMoE calls CuteDslFusedMoE.run_moe. + # This forward_chunk method is a reference implementation of the legacy path. + # Apply routing + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits) + assert token_selected_experts.shape[ + 1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + x, x_sf = self.quantize_input(x) + + if self.use_dp and self.parallel_size > 1: + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + + x = self.run_moe(x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + enable_alltoall=False) + return x diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index bb883ffc953..c300243dff6 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -76,6 +76,7 @@ def __init__( swiglu_beta: Optional[torch.Tensor] = None, swiglu_limit: Optional[torch.Tensor] = None, init_load_balancer: bool = True, + without_comm: bool = False, activation_type: ActivationType = ActivationType.Swiglu, ): @@ -138,49 +139,58 @@ def __init__( self.has_been_profiled = False self.has_been_profiled_min_latency = False - # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future. 
-        self.alltoall_method_type = self.select_alltoall_method_type()
-        logger.info_once(
-            f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
-            key="alltoall_method_type")
-        self.alltoall_workspace = None
-        self.alltoall_prepare_workspace = None
-        self.use_low_precision_combine = False
-        if self.enable_alltoall:
-            self.use_low_precision_combine = model_config.use_low_precision_moe_combine
-
-            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
-                MnnvlMemory.initialize()
-                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                    model_config.mapping)
-                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                    model_config.mapping)
-            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
-                # Calculate required workspace size
-                ep_size = self.mapping.moe_ep_size
-                max_num_tokens = model_config.max_num_tokens
-                hidden_size = self.hidden_size
-                dtype = self.dtype or torch.float16
-
-                workspace_size = MoeAlltoAll.calculate_required_workspace_size(
-                    ep_size, self.routing_method.experts_per_token,
-                    max_num_tokens, hidden_size, dtype)
-
-                self.moe_a2a = MoeAlltoAll(
-                    mapping=self.mapping,
-                    max_num_tokens=model_config.max_num_tokens,
-                    top_k=self.routing_method.experts_per_token,
-                    num_experts=self.num_slots,
-                    workspace_size_per_rank=workspace_size,
-                )
-            elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                raise NotImplementedError(
-                    "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
-                )
-            else:
-                raise NotImplementedError(
-                    f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
-                )
+        # When without_comm=True, skip communication initialization (ConfigurableMoE will handle it)
+        if not without_comm:
+            self.alltoall_method_type = self.select_alltoall_method_type()
+            logger.info_once(
+                f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
+                key="alltoall_method_type")
+            self.alltoall_workspace = None
+            self.alltoall_prepare_workspace = None
+            self.use_low_precision_combine = False
+            if self.enable_alltoall:
+                self.use_low_precision_combine = model_config.use_low_precision_moe_combine
+
+                if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
+                    MnnvlMemory.initialize()
+                    self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                        model_config.mapping)
+                    self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                        model_config.mapping)
+                elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
+                    # Calculate required workspace size
+                    ep_size = self.mapping.moe_ep_size
+                    max_num_tokens = model_config.max_num_tokens
+                    hidden_size = self.hidden_size
+                    dtype = self.dtype or torch.float16
+
+                    workspace_size = MoeAlltoAll.calculate_required_workspace_size(
+                        ep_size, self.routing_method.experts_per_token,
+                        max_num_tokens, hidden_size, dtype)
+
+                    self.moe_a2a = MoeAlltoAll(
+                        mapping=self.mapping,
+                        max_num_tokens=model_config.max_num_tokens,
+                        top_k=self.routing_method.experts_per_token,
+                        num_experts=self.num_slots,
+                        workspace_size_per_rank=workspace_size,
+                    )
+                elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                    raise NotImplementedError(
+                        "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
+                    )
+        else:
+            # When without_comm=True, set minimal attributes
+            # Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
+            self.alltoall_method_type = AlltoallMethodType.NotEnabled
+            self.alltoall_workspace = None
+            self.alltoall_prepare_workspace = None
+            self.use_low_precision_combine = False
+            self.moe_a2a = None
 
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index f2bcc4397ac..d1253a838fd 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -156,14 +156,14 @@ def __init__(
                 raise NotImplementedError(
                     f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
                 )
-        else:
-            # When without_comm=True, set minimal attributes
-            # Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
-            self.alltoall_method_type = AlltoallMethodType.NotEnabled
-            self.alltoall_workspace = None
-            self.alltoall_prepare_workspace = None
-            self.use_low_precision_combine = False
-            self.moe_a2a = None
+        else:
+            # When without_comm=True, set minimal attributes
+            # Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
+            self.alltoall_method_type = AlltoallMethodType.NotEnabled
+            self.alltoall_workspace = None
+            self.alltoall_prepare_workspace = None
+            self.use_low_precision_combine = False
+            self.moe_a2a = None
 
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index cb61b7867f3..7ea7e95539a 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -2085,34 +2085,44 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
 
 class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
 
-    def post_load_weights(self, module: torch.nn.Module):
-        super().post_load_weights(module)
+    def load_expert_w3_w1_weight(self, module: torch.nn.Module,
+                                 w1_weight: torch.Tensor,
+                                 w3_weight: torch.Tensor,
+                                 dst_w3_w1_weight: torch.Tensor):
+        super().load_expert_w3_w1_weight(module, w1_weight, w3_weight,
+                                         dst_w3_w1_weight)
 
-        # Interleave FC1 weight and scales for GEMM1 + SwiGLU fusion.
-        w3_w1_weight = module.w3_w1_weight.data.view(float4_e2m1x2)
-        m = w3_w1_weight.size(1)
-        n = w3_w1_weight.size(2) * 2
+        # Interleave FC1 weight for GEMM1 + SwiGLU fusion.
+        w3_w1_weight = dst_w3_w1_weight.cuda().view(float4_e2m1x2)
         w3_w1_weight_interleaved = interleave_linear_and_gate(w3_w1_weight,
                                                               group_size=64,
-                                                              dim=1)
+                                                              dim=0)
         w3_w1_weight_interleaved = w3_w1_weight_interleaved.view(
-            module.w3_w1_weight.data.dtype)
-        module.w3_w1_weight.data.copy_(w3_w1_weight_interleaved)
+            dst_w3_w1_weight.dtype)
+        dst_w3_w1_weight.copy_(w3_w1_weight_interleaved)
 
-        w3_w1_weight_scale = module.quant_scales.fc1_weight_block.data.view(
-            float4_sf_dtype)
+    def load_expert_w3_w1_weight_scale_nvfp4(
+            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            w3_weight_scale: torch.Tensor,
+            dst_w3_w1_weight_scale: torch.Tensor):
+        super().load_expert_w3_w1_weight_scale_nvfp4(module, w1_weight_scale,
+                                                     w3_weight_scale,
+                                                     dst_w3_w1_weight_scale)
+
+        # Interleave FC1 scales for GEMM1 + SwiGLU fusion.
+        n = module.intermediate_size_per_partition * 2
+        k = module.hidden_size
+        w3_w1_weight_scale = dst_w3_w1_weight_scale.cuda().view(float4_sf_dtype)
         w3_w1_weight_scale_unswizzled = unswizzle_sf(
-            w3_w1_weight_scale, m, n).view(-1, m,
-                                           n // module.scaling_vector_size)
+            w3_w1_weight_scale, n, k).view(n, k // module.scaling_vector_size)
         w3_w1_weight_scale_unswizzled_interleaved = interleave_linear_and_gate(
-            w3_w1_weight_scale_unswizzled, group_size=64, dim=1)
+            w3_w1_weight_scale_unswizzled, group_size=64, dim=0)
         w3_w1_weight_scale_interleaved = swizzle_sf(
-            w3_w1_weight_scale_unswizzled_interleaved, m,
-            n).view(-1, m, n // module.scaling_vector_size)
+            w3_w1_weight_scale_unswizzled_interleaved, n,
+            k).view(n, k // module.scaling_vector_size)
         w3_w1_weight_scale_interleaved = w3_w1_weight_scale_interleaved.view(
-            module.quant_scales.fc1_weight_block.data.dtype)
-        module.quant_scales.fc1_weight_block.data.copy_(
-            w3_w1_weight_scale_interleaved)
+            dst_w3_w1_weight_scale.dtype)
+        dst_w3_w1_weight_scale.copy_(w3_w1_weight_scale_interleaved)
 
 
 class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py
index 0768a777e0a..f26d7bf81e7 100644
--- a/tests/unittest/_torch/misc/test_autotuner.py
+++ b/tests/unittest/_torch/misc/test_autotuner.py
@@ -112,13 +112,13 @@ def forward(self,
                    mutates_args=())
 def get_best_gemm_tactic(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
     runners = [GemmRunner()]
-    tunner = AutoTuner.get()
+    tuner = AutoTuner.get()
     tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
         input_idx=0,
         dim_idx=0,
         gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
         map_to_tuning_buckets=next_positive_power_of_2), ), )
-    runner, tactic = tunner.choose_one(
+    runner, tactic = tuner.choose_one(
         "autotuner_test::get_best_gemm_tactic",
         runners,
         tuning_config,
@@ -175,20 +175,20 @@ def forward(self,
     x, w = torch.randn(M, 64), torch.randn(64, 128)
     runners = [PartialCrashedRunner()]
-    tunner = AutoTuner.get()
+    tuner = AutoTuner.get()
     tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
         input_idx=0,
         dim_idx=0,
         gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
         map_to_tuning_buckets=next_positive_power_of_2), ), )
 
     with autotune():
-        runner, tactic = tunner.choose_one("test_autotuner_try_block", runners,
-                                           tuning_config, [x, w])
+        runner, tactic = tuner.choose_one("test_autotuner_try_block", runners,
+                                          tuning_config, [x, w])
 
     m = M // 2
     while m >= 1:
-        _, tactic = tunner.choose_one("test_autotuner_try_block", runners,
-                                      tuning_config, [torch.randn(m, 64), w])
+        _, tactic = tuner.choose_one("test_autotuner_try_block", runners,
+                                     tuning_config, [torch.randn(m, 64), w])
         assert tactic in [
             -1, 0
         ], f"Expect only tactic -1, 0 being chosen, but got tactic {tactic}."
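For readers of the NVFP4CuteDslFusedMoEMethod changes above: the patch calls interleave_linear_and_gate(..., group_size=64, dim=0) but does not show what that interleaving does. The following is a minimal pure-PyTorch sketch of the idea under stated assumptions: the fused FC1 tensor stacks w3 (gate) rows in the first half and w1 (linear) rows in the second half along dim 0, and groups of 64 rows from each half are alternated. The helper name interleave_groups_dim0 and the gate-first ordering are assumptions for illustration, not the library's implementation.

import torch


def interleave_groups_dim0(w3_w1: torch.Tensor, group_size: int = 64) -> torch.Tensor:
    # Assumption: rows [0, N/2) hold w3 (gate) and rows [N/2, N) hold w1 (linear).
    # Groups of `group_size` rows from the two halves are alternated along dim 0,
    # so a fused GEMM1 + SwiGLU kernel can read matching gate/linear tiles together.
    n = w3_w1.size(0)
    assert n % (2 * group_size) == 0
    half = n // 2
    rest = w3_w1.shape[1:]
    w3_groups = w3_w1[:half].reshape(half // group_size, 1, group_size, *rest)
    w1_groups = w3_w1[half:].reshape(half // group_size, 1, group_size, *rest)
    interleaved = torch.cat([w3_groups, w1_groups], dim=1)  # alternate gate/linear groups
    return interleaved.reshape(n, *rest)

A scale tensor laid out as (n, k // scaling_vector_size), as produced by unswizzle_sf in the hunk above, could be passed through the same sketch to mirror the weight interleaving.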
diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py
index cea8ff57e29..4ba09fa79c1 100644
--- a/tests/unittest/_torch/modules/test_fused_moe.py
+++ b/tests/unittest/_torch/modules/test_fused_moe.py
@@ -1364,9 +1364,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend):
         if dtype == torch.float16:
             pytest.skip(
                 "CUTEDSL NVFP4 MoE backend does not support float16 yet")
-        if get_sm_version() != 100:
+        if get_sm_version() not in (100, 103):
             pytest.skip(
-                "CUTEDSL NVFP4 MoE backend is only supported on SM 100 GPUs")
+                "CUTEDSL NVFP4 MoE backend supports SM 100 (B200) and SM 103 (B300) only"
+            )
 
     test_all_kernels = True
     if get_sm_version() == 120:
diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py
index 3fa0aa20285..4faec5d6f13 100644
--- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py
+++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py
@@ -227,6 +227,60 @@ def test_moe_unpermute(dtype: str, num_tokens: int, top_k: int, tile_size: int):
     torch.testing.assert_close(x, x_ref)
 
 
+@pytest.mark.parametrize("tile_size", [128, 256])
+@pytest.mark.parametrize("ep_size", [1, 8, 32])
+@pytest.mark.parametrize("top_k", [1, 2, 8])
+@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
+@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
+def test_moe_output_memset_inplace(
+    dtype: str, num_tokens: int, top_k: int, ep_size: int, tile_size: int
+):
+    dtype = getattr(torch, dtype)
+    hidden_size = 4096
+    num_experts = 256
+    num_local_experts = num_experts // ep_size
+    enable_alltoall = True
+
+    routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
+    token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1)
+    token_selected_experts = token_selected_experts.to(torch.int32)
+    token_final_scales = token_final_scales.softmax(dim=-1).to(torch.float32)
+
+    (
+        tile_idx_to_group_idx,
+        tile_idx_to_mn_limit,
+        expanded_idx_to_permuted_idx,
+        permuted_idx_to_expanded_idx,
+        total_num_padded_tokens,
+        num_non_exiting_tiles,
+    ) = torch.ops.trtllm.moe_sort(
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        num_experts=num_experts,
+        top_k=top_k,
+        local_expert_offset=0,
+        local_num_experts=num_local_experts,
+        tile_tokens_dim=tile_size,
+    )
+
+    x = torch.ones(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    torch.ops.trtllm.moe_output_memset_inplace(
+        x,
+        tile_idx_to_mn_limit,
+        expanded_idx_to_permuted_idx,
+        permuted_idx_to_expanded_idx,
+        num_non_exiting_tiles,
+        tile_size,
+        top_k,
+        ep_size,
+        enable_alltoall=enable_alltoall,
+    )
+    x_ref = torch.zeros_like(x)
+    if enable_alltoall and ep_size > top_k:
+        x_ref[(expanded_idx_to_permuted_idx < 0).all(dim=-1)] = 1
+    torch.testing.assert_close(x, x_ref)
+
+
 @pytest.mark.parametrize("tile_size", [128, 256])
 @pytest.mark.parametrize("top_k", [1, 2, 8])
 @pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@@ -257,7 +311,10 @@ def test_moe_swiglu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
     torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])
 
 
-@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
+@pytest.mark.skipif(
+    get_sm_version() not in (100, 103),
+    reason="This test is only supported on SM 100 and SM 103 GPUs",
+)
 @pytest.mark.parametrize("tile_size", [128, 256])
 @pytest.mark.parametrize("top_k", [1, 2, 8])
 @pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@@ -332,7 +389,10 @@ def test_moe_gelu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
     torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])
 
 
-@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
+@pytest.mark.skipif(
+    get_sm_version() not in (100, 103),
+    reason="This test is only supported on SM 100 and SM 103 GPUs",
+)
 @pytest.mark.parametrize("tile_size", [128])
 @pytest.mark.parametrize("ep_size", [1, 8, 32])
 @pytest.mark.parametrize("top_k", [1, 2, 8])
@@ -425,7 +485,10 @@ def test_nvfp4_grouped_gemm_blackwell(num_tokens: int, top_k: int, ep_size: int,
     torch.testing.assert_close(c[:num_valid_permuted_tokens], c_ref[:num_valid_permuted_tokens])
 
 
-@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
+@pytest.mark.skipif(
+    get_sm_version() not in (100, 103),
+    reason="This test is only supported on SM 100 and SM 103 GPUs",
+)
 @pytest.mark.parametrize("tile_size", [128])
 @pytest.mark.parametrize("ep_size", [1, 8, 32])
 @pytest.mark.parametrize("top_k", [1, 2, 8])
@@ -523,7 +586,10 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
     assert match_ratio > 0.99
 
 
-@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
+@pytest.mark.skipif(
+    get_sm_version() not in (100, 103),
+    reason="This test is only supported on SM 100 and SM 103 GPUs",
+)
 @pytest.mark.parametrize("tile_size", [128])
 @pytest.mark.parametrize("ep_size", [1, 8, 32])
 @pytest.mark.parametrize("top_k", [1, 2, 8])