From 110503f33a7dbf1ef67622d9f6351dfe59d86814 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 15 Aug 2025 10:42:03 +1200 Subject: [PATCH 01/13] perf: Make finalize fusion part of the tactic selection logic Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../mixtureOfExpertsBackendBenchmarkFixture.h | 21 +- ...ixtureOfExpertsBackendBenchmarkLauncher.cu | 198 ++---------------- .../include/cutlass_extensions/gemm_configs.h | 4 +- .../cutlass_kernels/include/moe_kernels.h | 50 +++-- .../moe_gemm/moe_gemm_template_dispatch.h | 8 +- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 8 +- .../include/moe_gemm_kernels.h | 8 +- .../mixtureOfExpertsPlugin.cpp | 6 +- cpp/tensorrt_llm/thop/moeOp.cpp | 27 ++- .../kernels/mixtureOfExpertsTest.cu | 47 +++-- tensorrt_llm/_torch/autotuner.py | 3 +- .../_torch/custom_ops/torch_custom_ops.py | 3 +- 12 files changed, 140 insertions(+), 243 deletions(-) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 2559ae54840..36cbe76544a 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -833,7 +833,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture // Runs for 3 iterations or 1 second and picks the best option int pickBestTactic(MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile) { - auto tactics = mMoERunner.getTactics(); + auto tactics = mMoERunner.getTactics(static_cast(gemm_to_profile)); ::nvtx3::scoped_range nvtx(tensorrt_llm::common::nvtx::nextColor(), "Tactic Profiling GEMM " + std::to_string(static_cast(gemm_to_profile))); // We save space by reusing the same workspace buffer for all tactics when doing full layer profiling. So we @@ -925,12 +925,14 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture std::pair setTactic( int tactic_idx1, int tactic_idx2, MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile) { - auto tactics = mMoERunner.getTactics(); + auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1); + auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2); std::vector, GemmToProfile>> tactics_to_profile{ {tactic_idx1, GemmToProfile::GEMM_1}, {tactic_idx2, GemmToProfile::GEMM_2}}; for (auto& combo : tactics_to_profile) { auto& t = combo.first.get(); + auto& tactics = combo.second == GemmToProfile::GEMM_1 ? tactics1 : tactics2; if (combo.second != gemm_to_profile && gemm_to_profile != GemmToProfile::LAYER) { t = 0; // Unneeded tactic, set to 0 @@ -947,7 +949,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture } } - mMoERunner.setTactic(tactics[tactic_idx1], tactics[tactic_idx2]); + mMoERunner.setTactic(tactics1[tactic_idx1], tactics2[tactic_idx2]); mBestTacticGemm1 = tactic_idx1; mBestTacticGemm2 = tactic_idx2; return {tactic_idx1, tactic_idx2}; @@ -965,7 +967,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture auto expert_weights_size = gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size; - auto tactics = mMoERunner.getTactics()[tactic_idx]; + auto tactics = mMoERunner.getTactics(static_cast(gemm_to_profile))[tactic_idx]; if (static_cast(gemm_to_profile) != static_cast(mGemmProfilerBackend.mGemmToProfile)) { throw std::runtime_error("Configuration mismatch between mGemmProfilerBackend and runMoEPermute"); @@ -1074,11 +1076,12 @@ void MixtureOfExpertsBenchmark::runBenchmark(benchmark::State& state } if (LOG_LEVEL >= INFO) { - auto tactics = mMoERunner.getTactics(); - std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics.size() << "\n" - << tactics[tactic_idx1].toString() << std::endl; - std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics.size() << "\n" - << tactics[tactic_idx2].toString() << std::endl; + auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1); + auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2); + std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics1.size() << "\n" + << tactics1[tactic_idx1].toString() << std::endl; + std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics2.size() << "\n" + << tactics2[tactic_idx2].toString() << std::endl; } state.counters["tactic_idx1"] = tactic_idx1; state.counters["tactic_idx2"] = tactic_idx2; diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu index b784c6d0bc4..31c7cc84e3c 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu @@ -42,144 +42,11 @@ struct WeightParams ->Apply(argGen>>) template -auto listAllTactics() +auto listAllTactics(MoeGemmId gemm_id) { int const sm = getSMVersion(); using RunnerType = decltype(BenchClass::mMoERunner); - return RunnerType::getTactics(sm); -} - -template -int parseTacticToId(nlohmann::json tactic_config) -{ - bool is_tma_warp_specialized = tactic_config.at("is_tma_warp_specialized").get(); - int tile_shape_id = -1; - std::array tile_shape; - if (tactic_config.at("tile_shape").is_array()) - tactic_config.at("tile_shape").get_to(tile_shape); - else - tile_shape_id = tactic_config.at("tile_shape").get(); - - std::vector confs = listAllTactics(); - - try - { - for (int i = 0; i < confs.size(); i++) - { - auto const& c = confs[i]; - if (c.is_tma_warp_specialized != is_tma_warp_specialized) - continue; - - if (!is_tma_warp_specialized) - { - int stages = tactic_config.at("stages").get(); - if (c.stages != stages) - continue; - } - - if (tile_shape_id != -1) - { - int comp = c.getTileConfigAsInt(); - if (tile_shape_id != comp) - continue; - if (is_tma_warp_specialized && (int) c.cluster_shape != tactic_config.at("cluster_shape").get()) - continue; - - // Found matching config - return i; - } - - // Handle if the user provided a shape instead of the enum value - if (is_tma_warp_specialized) - { - // TODO Add cases for blackwell shapes - using Kv = uint64_t; - constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); }; - static std::unordered_map const tile_map{ - {K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B}, - {K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B}, - {K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B}, - {K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B}, - {K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B}, - - {K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B}, - {K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B}, - {K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B}, - {K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B}, - {K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B}, - {K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B}, - }; - - if (c.getTileConfigAsInt() != (int) tile_map.at(K(tile_shape[0], tile_shape[1]))) - continue; - - static std::unordered_map const cluster_map{ - // CTA configs for M=64 - {K(1, 1), ClusterShape::ClusterShape_1x1x1}, - {K(2, 1), ClusterShape::ClusterShape_2x1x1}, - {K(1, 2), ClusterShape::ClusterShape_1x2x1}, - {K(2, 2), ClusterShape::ClusterShape_2x2x1}, - }; - - std::array cluster_shape; - tactic_config.at("cluster_shape").get_to(cluster_shape); - - if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1]))) - continue; - - // Found matching config - return i; - } - else - { - std::array warp_shape; - tactic_config.at("warp_shape").get_to(warp_shape); - - using Kv = uint64_t; - constexpr static auto K = [](std::array a, std::array b) - { - uint64_t sum = 0; - for (auto v : a) - sum = sum * 512 + v; - for (auto v : b) - sum = sum * 256 + v; - return sum; - }; - static std::unordered_map tile_map{ - {K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}, - - {K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64}, - {K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64}, - - {K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}, - {K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64}, - {K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}, - - {K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64}, - {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}, - {K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64}, - {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}, - {K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64}, - - {K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}, - - {K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64} - - }; - if (c.tile_config_sm80 != tile_map.at(K(tile_shape, warp_shape))) - continue; - - // Found matching config - return i; - } - } - } - catch (std::out_of_range const& e) - { - std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl; - } - - return -1; + return RunnerType::getTactics(sm, gemm_id); } template @@ -196,10 +63,6 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids) parseTacticToVectorID(c, tactic_ids); } } - else if (tactic.is_object()) - { - tactic_ids.push_back(parseTacticToId(tactic)); - } else if (tactic.is_string()) { assert(tactic.is_string()); @@ -415,20 +278,11 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) std::vector tactic_ids2{}; if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2")) { - if (run_config.contains("tactic_id")) - { - throw std::invalid_argument("Cannot use tactic_id and tactic_idX"); - } has_tactic_ids2 = true; - parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1); - parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2); - } - else - { - parseTacticToVectorID(run_config["tactic_id"], tactic_ids1); - has_tactic_ids2 = false; - tactic_ids2.resize(1); // Dummy value so we loop exactly once below + parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1); + parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2); } + if (tactic_ids1.empty() || tactic_ids2.empty()) { std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl; @@ -531,7 +385,7 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark) // {ActivationType::Relu, ActivationType::Gelu, // ActivationType::Silu, ActivationType::Geglu, // ActivationType::Swiglu}; - auto cutlass_tactic = {-1}; // {0,..., listAllTactics().size()}; + auto cutlass_tactic = {-1}; // {0,..., listAllTactics(MoeGemmId).size()}; auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2}; for (auto num_expert : num_experts) @@ -558,14 +412,18 @@ void argGen(benchmark::internal::Benchmark* benchmark) { if (LOG_LEVEL >= VERBOSE) { - std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n"; - int i = 0; - for (auto& t : listAllTactics()) + std::cout << "== List of all tactics for dtype " << (int) BenchClass::toDTypeID() << " ==\n"; + for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2}) { - std::cout << "Tactic " << i << ":\n"; - std::cout << t.toString() << std::endl; + int i = 0; + std::cout << "=== GEMM " << (int) gemm_id << " ===\n"; + for (auto& t : listAllTactics(gemm_id)) + { + std::cout << "==== Tactic " << i << " ====\n"; + std::cout << t.toString() << std::endl; - i++; + i++; + } } } @@ -652,7 +510,6 @@ void help() " \"bias\": int, (optional)\n" " \"do_final_scale\": int, (optional)\n" " \"act_fn\": int,\n" - " \"tactic_id\": tactic, (see below)\n" " \"tactic_id1\": tactic, (see below)\n" " \"tactic_id2\": tactic, (see below)\n" " \"dtypes\": [string, ...], (optional)\n" @@ -676,27 +533,14 @@ void help() "- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n" "- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = " "swiglu\n" - "- \"tactic_id, tactic_id1, tactic_id2\"\n" - "The config for the CUTLASS GEMM. tactic_id sets the same tactic for both to the same tactic (except in " - "auto mode)\n" - "Use tactic_idX to set the tactic for the corresponding GEMM" + "- \"tactic_id1, tactic_id2\"\n" + "The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM" "Valid tactics are:\n" - " - An object:\n" - " {\n" - " \"is_tma_warp_specialized\": bool,\n" - " \"tile_shape\": [int, int, int] or int,\n" - " \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape " - "is " - "an int)\n" - " \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n" - " \"stages\": int, (required for non-sm90)\n" - " },\n" - " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test " - "configurations\n" + " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between data types " + "or GPU architectures\n" " - An array: of integers or objects, forms a list of tactics to sweep\n" " - The string \"all\": This will sweep through all possible tactics\n" - " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark " - "case. " + " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. " "Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate " "results" "- dtypes - A list of dtypes to run this config through.\n" diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index fe75687e368..80fc8c02cec 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -432,6 +432,7 @@ struct CutlassGemmConfig bool enableCudaKernel = false; int sm_version = 80; // Use 80 as a catch all for <90 bool is_tma_warp_specialized = false; + bool is_finalize_fusion = false; CutlassGemmConfig() = default; @@ -502,7 +503,8 @@ struct CutlassGemmConfig << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() << "\n\tcluster shape ID: " << (int) cluster_shape << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule - << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false") + << "\n\tis_finalize_fusion: " << (is_finalize_fusion ? "true" : "false"); } else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 7d592bed0e4..0d0bbd1c068 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -228,6 +228,13 @@ struct MOEParallelismConfig } }; +enum class MoeGemmId : int +{ + Undefined = 0, + GEMM_1, + GEMM_2 +}; + struct QuantParams { // Int weight only quantization params @@ -446,7 +453,7 @@ class CutlassMoeFCRunnerInterface virtual void setTactic(std::optional gemm1_config, std::optional gemm2_config) = 0; - virtual std::vector getTactics() = 0; + virtual std::vector getTactics(MoeGemmId gemm_id) = 0; virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, @@ -593,15 +600,34 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface gemm2_config_ = std::move(gemm2_config); } - std::vector getTactics() override + static auto& addFinalizeFusionConfigs( + std::vector& configs, bool use_fused_finalize) { - return moe_gemm_runner_.getConfigs(); + if (!use_fused_finalize) + return configs; + + size_t const num_configs = configs.size(); + for (size_t i = 0; i < num_configs; ++i) + { + if (configs[i].is_tma_warp_specialized) + { + configs.push_back(configs[i]); + configs.back().is_finalize_fusion = true; + } + } + return configs; } - static std::vector getTactics(int sm) + std::vector getTactics(MoeGemmId gemm_id) override + { + return addFinalizeFusionConfigs( + moe_gemm_runner_.getConfigs(), gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); + } + + static std::vector getTactics(int sm, MoeGemmId gemm_id) { using RunnerType = decltype(moe_gemm_runner_); - return RunnerType::getConfigs(sm); + return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused(sm)); } void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, @@ -798,6 +824,12 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface && !use_w4_groupwise; } + static bool mayHaveFinalizeFused(int sm) + { + using RunnerType = decltype(moe_gemm_runner_); + return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise; + } + // TODO: This should eventually take the quant params to give more flexibility static auto getScalingType() { @@ -895,12 +927,7 @@ struct GemmProfilerBackend { public: using Config = cutlass_extensions::CutlassGemmConfig; - enum class GemmToProfile - { - Undefined = 0, - GEMM_1, - GEMM_2 - }; + using GemmToProfile = MoeGemmId; void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype, nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size, @@ -951,7 +978,6 @@ struct GemmProfilerBackend CutlassMoeFCRunnerInterface* mInterface; GemmToProfile mGemmToProfile = GemmToProfile::Undefined; - std::vector mAllTacticsSaved; int mSM{}; int64_t mNumExperts{}; int64_t mNumExpertsPerNode{}; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 56a8299f18f..a6238883499 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -566,11 +566,11 @@ bool MoeGemmRunner::isTmaWarpSpecializ } template -bool MoeGemmRunner::supportsTmaWarpSpecialized() const +bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm) { - return (sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - || (sm_ >= 100 && sm_ < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) - || ((sm_ == 120 || sm_ == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation()); + return (sm == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + || (sm >= 100 && sm < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) + || ((sm == 120 || sm == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation()); } template diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 730840717c2..c0fa39385da 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -2847,9 +2847,8 @@ void CutlassMoeFCRunnerusing_fused_finalize ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr; bool const is_gated_activation = isGatedActivation(activation_type); bool const gemm1_using_fused_moe @@ -4007,7 +4006,9 @@ CutlassMoeFCRunner:: bool apply_bias = parallelism_config.tp_rank == 0; auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr; bool using_fused_finalize - = use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora; + = use_fused_finalize_ && gemm2_config_->is_finalize_fusion && !use_w4_groupwise && !use_lora; + TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_config_->using_fused_finalize, + "GEMM2 tactic requests finalize fusion, but the runner is not configured to use it"); if (using_fused_finalize) { assert(min_latency_mode == false); @@ -4652,7 +4653,6 @@ void GemmProfilerBackend::prepareTmaWsInputs( void GemmProfilerBackend::prepare( int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream) { - mAllTacticsSaved = mInterface->getTactics(); mSampleIndex = 0; auto workspace_size = getWorkspaceSize(num_tokens); diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h index 3a72417a216..06ab4047ad2 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h @@ -290,7 +290,13 @@ class MoeGemmRunner static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsTmaWarpSpecialized() const; + + [[nodiscard]] bool supportsTmaWarpSpecialized() const + { + return supportsTmaWarpSpecialized(sm_); + } + + [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm); [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const; [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index 189e23b8acb..59d92e64290 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -946,8 +946,8 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, std::optional gemm2; if (common::getEnvForceDeterministicMOE()) { - gemm1 = mMOERunner->getTactics()[0]; - gemm2 = mMOERunner->getTactics()[0]; + gemm1 = mMOERunner->getTactics(MoeGemmId::GEMM_1)[0]; + gemm2 = mMOERunner->getTactics(MoeGemmId::GEMM_2)[0]; } else { @@ -1278,7 +1278,7 @@ void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, MixtureOfExper auto MixtureOfExpertsGemmProfiler::getTactics(int m, int n, int k) const -> std::vector { assert(mRunner); - return mRunner->mMOERunner->getTactics(); + return mRunner->mMOERunner->getTactics(backend.mGemmToProfile); } void MixtureOfExpertsGemmProfiler::initTmpData( diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 328cce3d014..91cbb9d8c34 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -215,7 +215,8 @@ class FusedMoeRunner : public torch::CustomClassHolder mKernelRunner->use_fused_finalize_ = mUseFusedFinalize; mProfiler = std::make_shared(); - mAllProfiles = mKernelRunner->getTactics(); + mGemm1Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_1); + mGemm2Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_2); } ~FusedMoeRunner() @@ -585,10 +586,10 @@ class FusedMoeRunner : public torch::CustomClassHolder return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, active_expert_global_ids); } - int64_t getTacticNum() + int64_t getTacticNum(int gemm_idx) { std::lock_guard lock(mMutex); - return mAllProfiles.size(); + return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size(); } // TODO Update this to be able to tell if we are profiling swiglu bias @@ -624,10 +625,14 @@ class FusedMoeRunner : public torch::CustomClassHolder : group_size_; int const num_experts = static_cast(fc2_expert_weights.sizes()[0] * ep_size); + auto const gemm_to_profile + = (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2; + auto const& profiles = (gemm_idx == 1) ? mGemm1Profiles : mGemm2Profiles; + // Get specific profile configs according to the profile_id. // Fallback tactic is set to be 0 // TODO: use the best tactic id found offline for a better default inference perf - auto const& profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id]; + auto const& profile = profile_id == -1 ? profiles.front() : profiles[profile_id]; auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); @@ -638,8 +643,7 @@ class FusedMoeRunner : public torch::CustomClassHolder if (do_preparation) { // Set profiled gemm idx - mProfiler->mGemmToProfile - = (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2; + mProfiler->mGemmToProfile = gemm_to_profile; // mProfiler init auto parallelism_config = kernels::MOEParallelismConfig(static_cast(tp_size), @@ -704,7 +708,8 @@ class FusedMoeRunner : public torch::CustomClassHolder bool mUseFusedFinalize = true; using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - std::vector mAllProfiles; + std::vector mGemm1Profiles; + std::vector mGemm2Profiles; void freeProfileWorkspace() { @@ -730,15 +735,15 @@ class FusedMoeRunner : public torch::CustomClassHolder return; } - auto best_gemm1_profile = mAllProfiles.front(); - auto best_gemm2_profile = mAllProfiles.front(); + auto best_gemm1_profile = mGemm1Profiles.front(); + auto best_gemm2_profile = mGemm2Profiles.front(); if (profile_ids.has_value()) { TORCH_CHECK(profile_ids.value().size() == 2, "Expecting 2 profile ids"); best_gemm1_profile - = profile_ids.value()[0] == -1 ? best_gemm1_profile : mAllProfiles.at(profile_ids.value()[0]); + = profile_ids.value()[0] == -1 ? best_gemm1_profile : mGemm1Profiles.at(profile_ids.value()[0]); best_gemm2_profile - = profile_ids.value()[1] == -1 ? best_gemm2_profile : mAllProfiles.at(profile_ids.value()[1]); + = profile_ids.value()[1] == -1 ? best_gemm2_profile : mGemm2Profiles.at(profile_ids.value()[1]); } mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile); } diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index 6f2ce0f93e6..f822dd15a56 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -370,8 +370,8 @@ protected: float mSparseMixerEpsilon = 0.2f; - // Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths - bool mUseDeterministicHopperReduce = true; + // Default this to false. This only matters for K>2, and so by doing this we will test the fused and unfused paths + bool mUseFusedFinalize = false; // Disable this for long running tests to speed up runtime bool mIsLongTest = false; @@ -456,7 +456,7 @@ protected: { managed_buffers.clear(); - mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce; + mMoERunner.use_fused_finalize_ = k < 3 || mUseFusedFinalize; mHiddenSize = hidden_size; mInterSize = hidden_size * mInterSizeFraction; @@ -1087,9 +1087,9 @@ protected: return std::tuple{(void*) weight_1, (void*) weight_2, bias_1, bias2_ptr, scale_1, scale_2, scale_3}; } - auto getFilteredConfigs(int sm) + auto getFilteredConfigs(int sm, MoeGemmId gemm_id) { - auto tactics = mMoERunner.getTactics(); + auto tactics = mMoERunner.getTactics(gemm_id); if (sm == 89 || sm >= 120) { // Filter some unsupported configs for L40S @@ -1120,17 +1120,25 @@ protected: auto selectTacticsForArch(int sm) { bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT; - auto tactics = getFilteredConfigs(sm); - auto it = std::find_if(tactics.begin(), tactics.end(), + bool is_finalize_fusion = is_tma_warp_specialized && mUseFusedFinalize; + auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1); + auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2); + auto it1 = std::find_if(tactics1.begin(), tactics1.end(), [is_tma_warp_specialized](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized; }); - if (it == tactics.end()) + auto it2 = std::find_if(tactics2.begin(), tactics2.end(), + [is_tma_warp_specialized, is_finalize_fusion](auto& c) { + return c.is_tma_warp_specialized == is_tma_warp_specialized + && c.using_fused_finalize == is_finalize_fusion; + }); + if (it1 == tactics1.end() || it2 == tactics2.end()) { // Fall back to any tactic std::cout << "WARNING: Could not find config for sm version " << sm << std::endl; - return std::pair{tactics[0], tactics[0]}; + it1 = (it1 == tactics1.end()) ? tactics1.begin() : it1; + it2 = (it2 == tactics2.end()) ? tactics2.begin() : it2; } - return std::pair(*it, *it); + return std::pair(*it1, *it2); } using ConfigsToTestVec = std::vectorget(); auto tactic1 = mInternalSelectedConfig1; auto tactic2 = mInternalSelectedConfig2; - if (!tactic1) + if (!tactic1 || !tactic2) { int sm = getSMVersion(); - std::tie(tactic1, tactic2) = selectTacticsForArch(sm); + std::tie(tactic1, tactic2) = selectTacticsForArch(sm, mUseFusedFinalize); } ASSERT_TRUE(tactic1.has_value()); ASSERT_TRUE(tactic2.has_value()); @@ -1630,7 +1638,7 @@ void MixtureOfExpertsTest::BasicPermuteTest( runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k); bool should_be_deterministic - = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1749,7 +1757,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias) TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic) { - this->mUseDeterministicHopperReduce = false; + this->mUseFusedFinalize = true; // Just test case 3, cases 1&2 always use the fused paths this->BasicPermuteTest(3); } @@ -1897,7 +1905,7 @@ void MixtureOfExpertsTest::ParallelismTest( runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k, MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); bool should_be_deterministic - = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1913,7 +1921,7 @@ void MixtureOfExpertsTest::ParallelismTest( { runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); bool should_be_deterministic - = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -2103,12 +2111,13 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) auto activation_pool = std::vector{ActivationType::Relu, ActivationType::Swiglu, ActivationType::SwigluBias}; if (this->NVFP4) activation_pool = {ActivationType::Relu}; - auto configs = this->getFilteredConfigs(getSMVersion()); + auto configs1 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_1); + auto configs2 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_2); for (auto const activation_type : activation_pool) { - for (auto conf1 : configs) + for (auto conf1 : configs1) { - for (auto conf2 : configs) + for (auto conf2 : configs2) { auto name1 = genConfigName(conf1); auto name2 = genConfigName(conf2); diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index da4df91f693..aa1b250b3a1 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -453,7 +453,8 @@ def _profile_runners( p.name for p in inspect.signature(runner.forward).parameters.values() } - valid_tactics = runner.get_valid_tactics(input_tensors, profile) + valid_tactics = runner.get_valid_tactics(input_tensors, profile, + **kwargs) if "do_preparation" in runner_arg_names and len(valid_tactics) > 0: runner( input_tensors, diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index bd946343b09..719baaa450d 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -85,8 +85,9 @@ def get_valid_tactics( self, inputs: List[torch.Tensor], profile: OptimizationProfile, + gemm_idx: int, ) -> List[int]: - return range(self.fused_moe_runner.get_tactic_num()) + return range(self.fused_moe_runner.get_tactic_num(gemm_idx)) def forward( self, From 382a94ef39f6c85b265434d6e4f44492d0a7fa6c Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 15 Aug 2025 11:48:01 +1200 Subject: [PATCH 02/13] Fix build issues Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- ...ixtureOfExpertsBackendBenchmarkLauncher.cu | 17 +++++++++----- .../include/cutlass_extensions/gemm_configs.h | 14 ++++++++--- .../include/moe_gemm_kernels.h | 8 ++++++- .../cutlass_kernels/include/moe_kernels.h | 12 ++++++---- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 10 +++++--- .../include/moe_gemm_kernels.h | 8 +------ .../mixtureOfExperts/mixtureOfExpertsPlugin.h | 1 + cpp/tensorrt_llm/thop/moeOp.cpp | 4 +++- .../kernels/mixtureOfExpertsTest.cu | 23 ++++++++++++------- 9 files changed, 63 insertions(+), 34 deletions(-) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu index 31c7cc84e3c..c2a447d01a0 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu @@ -50,7 +50,7 @@ auto listAllTactics(MoeGemmId gemm_id) } template -void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids) +void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids, MoeGemmId gemm_id) { if (tactic.is_number_integer()) { @@ -60,7 +60,7 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids) { for (auto c : tactic) { - parseTacticToVectorID(c, tactic_ids); + parseTacticToVectorID(c, tactic_ids, gemm_id); } } else if (tactic.is_string()) @@ -69,7 +69,7 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids) auto tactic_name = tactic.get(); if (tactic_name == "all") { - auto all_tactics = listAllTactics(); + auto all_tactics = listAllTactics(gemm_id); tactic_ids.resize(all_tactics.size()); std::iota(tactic_ids.begin(), tactic_ids.end(), 0); } @@ -291,9 +291,14 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) { printed = true; std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n"; - auto confs = listAllTactics(); - for (auto c : confs) - std::cerr << c.toString(); + for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2}) + { + std::cerr << "GEMM " << (int) gemm_id << ":\n"; + auto confs = listAllTactics(gemm_id); + for (auto c : confs) + std::cerr << c.toString(); + std::cerr << std::endl; + } } continue; diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index 80fc8c02cec..a8b13f353a8 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -432,7 +432,14 @@ struct CutlassGemmConfig bool enableCudaKernel = false; int sm_version = 80; // Use 80 as a catch all for <90 bool is_tma_warp_specialized = false; - bool is_finalize_fusion = false; + + enum class EpilogueFusionType : int + { + NONE, + FINALIZE + }; + + EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE; CutlassGemmConfig() = default; @@ -504,7 +511,7 @@ struct CutlassGemmConfig << "\n\tcluster shape ID: " << (int) cluster_shape << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false") - << "\n\tis_finalize_fusion: " << (is_finalize_fusion ? "true" : "false"); + << "\n\tepilogue fusion type: " << (int) epilogue_fusion_type; } else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { @@ -536,7 +543,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) << ", cluster_shape_enum: " << int(config.cluster_shape) - << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false") + << ", epilogue_fusion_type: " << int(config.epilogue_fusion_type); } else { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 3c814851c91..77226417594 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -297,7 +297,13 @@ class MoeGemmRunner static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsTmaWarpSpecialized() const; + + [[nodiscard]] bool supportsTmaWarpSpecialized() const + { + return supportsTmaWarpSpecialized(sm_); + } + + [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm); [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const; [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 0d0bbd1c068..2c99f3d81cc 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -600,8 +600,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface gemm2_config_ = std::move(gemm2_config); } - static auto& addFinalizeFusionConfigs( - std::vector& configs, bool use_fused_finalize) + static auto addFinalizeFusionConfigs( + std::vector&& configs, bool use_fused_finalize) { if (!use_fused_finalize) return configs; @@ -612,7 +612,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface if (configs[i].is_tma_warp_specialized) { configs.push_back(configs[i]); - configs.back().is_finalize_fusion = true; + configs.back().epilogue_fusion_type + = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; } } return configs; @@ -620,14 +621,15 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface std::vector getTactics(MoeGemmId gemm_id) override { - return addFinalizeFusionConfigs( + return Self::addFinalizeFusionConfigs( moe_gemm_runner_.getConfigs(), gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); } static std::vector getTactics(int sm, MoeGemmId gemm_id) { using RunnerType = decltype(moe_gemm_runner_); - return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused(sm)); + return Self::addFinalizeFusionConfigs( + RunnerType::getConfigs(sm), gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm)); } void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index c0fa39385da..9d6da3ff709 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -2847,8 +2847,10 @@ void CutlassMoeFCRunnerepilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; permuted_token_final_scales_ - = gemm2_config_->using_fused_finalize ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr; + = gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr; bool const is_gated_activation = isGatedActivation(activation_type); bool const gemm1_using_fused_moe @@ -4005,9 +4007,11 @@ CutlassMoeFCRunner:: bool apply_bias = parallelism_config.tp_rank == 0; auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr; + bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type + == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool using_fused_finalize - = use_fused_finalize_ && gemm2_config_->is_finalize_fusion && !use_w4_groupwise && !use_lora; - TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_config_->using_fused_finalize, + = use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora; + TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion, "GEMM2 tactic requests finalize fusion, but the runner is not configured to use it"); if (using_fused_finalize) { diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h index 06ab4047ad2..3a72417a216 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h @@ -290,13 +290,7 @@ class MoeGemmRunner static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; - - [[nodiscard]] bool supportsTmaWarpSpecialized() const - { - return supportsTmaWarpSpecialized(sm_); - } - - [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm); + [[nodiscard]] bool supportsTmaWarpSpecialized() const; [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const; [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h index cd3aaf52c20..feb1f10cdc7 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h @@ -43,6 +43,7 @@ namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using MoeMinLatencyParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MoeMinLatencyParams; using MOEParallelismConfig = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MOEParallelismConfig; using QuantParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::QuantParams; +using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; using ActivationParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 91cbb9d8c34..abeba273a84 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -48,6 +48,7 @@ namespace common = tensorrt_llm::common; namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using ActivationParams = CUTLASS_MOE_GEMM_NAMESPACE::ActivationParams; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId; // Always use public header as it is just utility functions and types using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend; @@ -586,9 +587,10 @@ class FusedMoeRunner : public torch::CustomClassHolder return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, active_expert_global_ids); } - int64_t getTacticNum(int gemm_idx) + int64_t getTacticNum(int64_t const gemm_idx) { std::lock_guard lock(mMutex); + TORCH_CHECK(gemm_idx == 1 || gemm_idx == 2, "gemm_idx must be 1 or 2"); return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size(); } diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index f822dd15a56..3564034aaee 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -1120,15 +1120,17 @@ protected: auto selectTacticsForArch(int sm) { bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT; - bool is_finalize_fusion = is_tma_warp_specialized && mUseFusedFinalize; + bool epilogue_fusion_type = is_tma_warp_specialized && mUseFusedFinalize + ? cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE + : cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE; auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1); auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2); auto it1 = std::find_if(tactics1.begin(), tactics1.end(), [is_tma_warp_specialized](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized; }); auto it2 = std::find_if(tactics2.begin(), tactics2.end(), - [is_tma_warp_specialized, is_finalize_fusion](auto& c) { + [is_tma_warp_specialized, epilogue_fusion_type](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized - && c.using_fused_finalize == is_finalize_fusion; + && c.epilogue_fusion_type == epilogue_fusion_type; }); if (it1 == tactics1.end() || it2 == tactics2.end()) { @@ -1175,7 +1177,7 @@ protected: if (!tactic1 || !tactic2) { int sm = getSMVersion(); - std::tie(tactic1, tactic2) = selectTacticsForArch(sm, mUseFusedFinalize); + std::tie(tactic1, tactic2) = selectTacticsForArch(sm); } ASSERT_TRUE(tactic1.has_value()); ASSERT_TRUE(tactic2.has_value()); @@ -1637,8 +1639,9 @@ void MixtureOfExpertsTest::BasicPermuteTest( auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k); runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k); - bool should_be_deterministic - = !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + bool is_finalize_fusion + = gemm2.epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1904,8 +1907,10 @@ void MixtureOfExpertsTest::ParallelismTest( // Only need to init the inputs on the first iteration runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k, MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); + bool is_finalize_fusion = gemm2.epilogue_fusion_type + == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic - = !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1920,8 +1925,10 @@ void MixtureOfExpertsTest::ParallelismTest( else { runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); + bool is_finalize_fusion = gemm2.epilogue_fusion_type + == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic - = !gemm2.is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); From 32bbd59b9f7271f9bc850bd30d03be88d256548d Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:24:22 +1200 Subject: [PATCH 03/13] Fix test issues Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../unit_tests/kernels/mixtureOfExpertsTest.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index 3564034aaee..11ae4273dc6 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -1120,9 +1120,9 @@ protected: auto selectTacticsForArch(int sm) { bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT; - bool epilogue_fusion_type = is_tma_warp_specialized && mUseFusedFinalize - ? cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE - : cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE; + auto epilogue_fusion_type = (is_tma_warp_specialized && mUseFusedFinalize) + ? tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE + : tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE; auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1); auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2); auto it1 = std::find_if(tactics1.begin(), tactics1.end(), @@ -1639,8 +1639,8 @@ void MixtureOfExpertsTest::BasicPermuteTest( auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k); runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k); - bool is_finalize_fusion - = gemm2.epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + bool is_finalize_fusion = gemm2.epilogue_fusion_type + == tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { @@ -1908,7 +1908,7 @@ void MixtureOfExpertsTest::ParallelismTest( runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k, MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); bool is_finalize_fusion = gemm2.epilogue_fusion_type - == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + == tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) @@ -1926,7 +1926,7 @@ void MixtureOfExpertsTest::ParallelismTest( { runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); bool is_finalize_fusion = gemm2.epilogue_fusion_type - == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + == tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) @@ -2092,6 +2092,7 @@ PARALLEL_TEST_SUITE(MixedParallel) TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) { this->mIsLongTest = true; + this->mUseFusedFinalize = true; // True for all cases because we sweep both auto genConfigName = [](auto conf) -> std::string { using namespace tensorrt_llm::cutlass_extensions; @@ -2136,7 +2137,6 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) this->mActType = activation_type; for (auto k : {2, 3}) { - this->mOverrideSelectedConfig1 = conf1; this->mOverrideSelectedConfig2 = conf2; this->BasicPermuteTest(k, this->MINIMUM_ALIGNMENT); From df5a057274c27ddef5a2dba0a73c972048acaab7 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 15 Aug 2025 17:30:17 +1200 Subject: [PATCH 04/13] Properly profile unfused finalize Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../cutlass_kernels/include/moe_kernels.h | 5 +-- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 31 +++++++++++++------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 2c99f3d81cc..7898f328a85 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -1000,7 +1000,7 @@ struct GemmProfilerBackend // This will be a unique value for every iteration of warmup and actual bench constexpr static int64_t NUM_ROUTING_SAMPLES = 16; - std::array mTmaInputCache; + std::array, NUM_ROUTING_SAMPLES> mTmaInputCache; QuantParams mQuantParams; bool mBias{}; @@ -1013,7 +1013,8 @@ struct GemmProfilerBackend private: void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream); void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream); - void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream); + void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream); }; // Populates a buffer with random values for use with MOE benchmarking diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 9d6da3ff709..5645f937fdb 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -4555,14 +4555,22 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr } } -void GemmProfilerBackend::prepareTmaWsInputs( - int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream) +void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream) { if (mSM < 90) { return; } + bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; + if (use_finalize_fusion + && (!mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise + || mGemmToProfile != GemmToProfile::GEMM_2)) + { + return; + } + auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90); #define GET_WS_PTR(type, name) \ @@ -4601,15 +4609,16 @@ void GemmProfilerBackend::prepareTmaWsInputs( size_t num_expanded_tokens = num_tokens * mK; for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { - mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, + auto& cache_element = mTmaInputCache[i][use_finalize_fusion]; + cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, workspaces.at("gemm_workspace").first, mScalingType); tma_ws_input_workspace += tma_ws_size; int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1); int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens; - auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input; - auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input; + auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input; + auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input; if (mSM >= 90) { /* GEMM1 */ @@ -4620,9 +4629,7 @@ void GemmProfilerBackend::prepareTmaWsInputs( bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && mWType == nvinfer1::DataType::kUINT8); bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; - bool using_fused_finalize - = mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise; - if (using_fused_finalize) + if (use_finalize_fusion) { assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; @@ -4664,7 +4671,10 @@ void GemmProfilerBackend::prepare( prepareRouting(num_tokens, workspace_ptr_char, stream); prepareQuantParams(num_tokens, workspace_ptr_char, stream); - prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, stream); + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, stream); + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, stream); } size_t GemmProfilerBackend::getWorkspaceSize(int maxM) @@ -4728,7 +4738,8 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac TmaWarpSpecializedGroupedGemmInput tma_ws_input_template; if (tactic.is_tma_warp_specialized) { - tma_ws_input_template = mTmaInputCache[mSampleIndex]; + tma_ws_input_template = mTmaInputCache[mSampleIndex][tactic.epilogue_fusion_type + == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE]; } mInterface->is_profiler = true; From 77afadbb1af71f9393ce0a65da4b06a97ebab578 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 15 Aug 2025 17:32:01 +1200 Subject: [PATCH 05/13] Fix docs in benchmark Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../mixtureOfExpertsBackendBenchmarkLauncher.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu index c2a447d01a0..bfa4e7e8e02 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu @@ -543,7 +543,7 @@ void help() "Valid tactics are:\n" " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between data types " "or GPU architectures\n" - " - An array: of integers or objects, forms a list of tactics to sweep\n" + " - An array: of integers, forms a list of tactics to sweep\n" " - The string \"all\": This will sweep through all possible tactics\n" " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. " "Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate " From 1fe1fdc5ead34fd8a7c2067d8339ffc5d76e1e6e Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:48:22 +1200 Subject: [PATCH 06/13] Fix compilation issues Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 5645f937fdb..0a2cdb4e5a5 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -4563,10 +4563,14 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr return; } + bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); + bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) + && mWType == nvinfer1::DataType::kUINT8); + bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - if (use_finalize_fusion - && (!mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise - || mGemmToProfile != GemmToProfile::GEMM_2)) + bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise + || mGemmToProfile != GemmToProfile::GEMM_2; + if (use_finalize_fusion && finalize_fusion_not_supported) { return; } @@ -4624,11 +4628,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr /* GEMM1 */ gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; - - bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); - bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) - && mWType == nvinfer1::DataType::kUINT8); - bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; if (use_finalize_fusion) { assert(!mMinLatencyMode); @@ -4740,6 +4739,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac { tma_ws_input_template = mTmaInputCache[mSampleIndex][tactic.epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE]; + TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), "TMA WS input template is not initialized"); } mInterface->is_profiler = true; From e00f3618fa387af54ad5ad64e07ae75b9cfb78a2 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Mon, 18 Aug 2025 11:10:21 +1200 Subject: [PATCH 07/13] Add explanatory comments Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../kernels/cutlass_kernels/moe_gemm/moe_kernels.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 0a2cdb4e5a5..ef70b9d45e7 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -4613,6 +4613,8 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr size_t num_expanded_tokens = num_tokens * mK; for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { + // Note: Even though we have separate TMA WS inputs for finalize fusion on/off we reuse the same pointers to + // save space. auto& cache_element = mTmaInputCache[i][use_finalize_fusion]; cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, workspaces.at("gemm_workspace").first, mScalingType); From 4dc2309f41fc29ba200e990209bd2a2c61e565e2 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Mon, 18 Aug 2025 11:18:46 +1200 Subject: [PATCH 08/13] Benchmark cleanups Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- ...ixtureOfExpertsBackendBenchmarkLauncher.cu | 53 +++++++++---------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu index bfa4e7e8e02..8e18694ad74 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu @@ -273,35 +273,15 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) } // Do this after filtering datatypes as tactics only make sense if we know the data type - bool has_tactic_ids2 = false; std::vector tactic_ids1{}; std::vector tactic_ids2{}; - if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2")) + if (run_config.contains("tactic_id1")) { - has_tactic_ids2 = true; parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1); - parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2); } - - if (tactic_ids1.empty() || tactic_ids2.empty()) + if (run_config.contains("tactic_id2")) { - std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl; - static bool printed = false; - if (!printed) - { - printed = true; - std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n"; - for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2}) - { - std::cerr << "GEMM " << (int) gemm_id << ":\n"; - auto confs = listAllTactics(gemm_id); - for (auto c : confs) - std::cerr << c.toString(); - std::cerr << std::endl; - } - } - - continue; + parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2); } auto get_or = [&](auto name, auto def) @@ -337,8 +317,6 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) } else if (gemm_to_profile == (int) GemmToProfile::GEMM_2) { - if (!has_tactic_ids2) - tactic_ids2 = std::move(tactic_ids1); tactic_ids1 = {-1}; } } @@ -353,14 +331,31 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) return val; }; + if (tactic_ids1.empty() || tactic_ids2.empty()) + { + std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl; + static bool printed = false; + if (!printed) + { + printed = true; + std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n"; + for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2}) + { + std::cerr << "GEMM " << (int) gemm_id << ":\n"; + auto confs = listAllTactics(gemm_id); + for (auto c : confs) + std::cerr << c.toString(); + std::cerr << std::endl; + } + } + + continue; + } + for (auto t1 : tactic_ids1) { - // tactic_ids2 will have one dummy value if has_tactic_ids2 = false for (auto t2 : tactic_ids2) { - if (!has_tactic_ids2) - t2 = t1; - benchmark->Args({num_experts, // get_range("k"), // get_range("hidden_size"), // From 1876c565d83e574d9af7463e9ef2805774c69edc Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:53:28 +1200 Subject: [PATCH 09/13] Add kwargs to signature of all get_valid_tactics calls for consistency Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../_torch/custom_ops/torch_custom_ops.py | 45 ++++++------------- .../custom_ops/trtllm_gen_custom_ops.py | 35 +++++---------- tests/unittest/_torch/misc/test_autotuner.py | 13 +++--- 3 files changed, 29 insertions(+), 64 deletions(-) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 719baaa450d..7d0c73364dc 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -81,13 +81,9 @@ def __init__( use_fused_finalize) self.fused_moe_runner = MoERunner.runner_dict[instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - gemm_idx: int, - ) -> List[int]: - return range(self.fused_moe_runner.get_tactic_num(gemm_idx)) + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: + return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"])) def forward( self, @@ -319,11 +315,8 @@ def __init__( self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[ instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.fp8_rowwise_gemm_runner.get_num_configs())) def forward( @@ -404,11 +397,8 @@ def __init__( output_dtype, int(fp4_gemm_type)) self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.fp4_gemm_runner.get_num_configs())) def forward( @@ -519,11 +509,8 @@ def forward( return out_tensors - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: mat1, mat2, _, _, _ = inputs @@ -736,11 +723,8 @@ def __init__( self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[ instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.weight_only_quant_gemm_runner.get_num_configs())) def forward( @@ -814,11 +798,8 @@ def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype, self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[ instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list( range(self._finegrained_mixed_dtype_gemm_runner.get_num_configs())) diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 2bb780f6ef2..bbee1b8102f 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -122,11 +122,8 @@ def forward( self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, self.do_finalize, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = FP4BlockScaleMoEInputs(*inputs) @@ -409,11 +406,8 @@ def forward( self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = FP8BlockScaleMoEInputs(*inputs) @@ -670,11 +664,8 @@ def forward( self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = MxE4m3MxE2m1BlockScaleMoEInputs(*inputs) @@ -907,11 +898,8 @@ def forward( self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = E4m3MxE2m1BlockScaleMoEInputs(*inputs) @@ -1123,11 +1111,8 @@ def forward( self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = Bf16MxE2m1BlockScaleMoEInputs(*inputs) diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index c2f5c32141a..5ed816df8d8 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -151,7 +151,8 @@ def test_autotuner_try_block(): class PartialCrashedRunner(TunableRunner): def get_valid_tactics(self, inputs: List[FakeTensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, + **kwargs) -> List[int]: return [-1, 0, 1] def forward(self, @@ -226,7 +227,7 @@ def __init__(self, block_size: int, num_warps: int): self.num_warps = num_warps def get_valid_tactics(self, inputs: List[FakeTensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, **kwargs) -> List[int]: return [-1, 0, 1] def forward(self, @@ -313,11 +314,9 @@ def test_multiple_dynamic_shapes_cache(): class GemmRunnerWithTacticConfigs(TunableRunner): valid_tactic_ids = [-1, 0, 1] - def get_valid_tactics( - self, - inputs: List[FakeTensor], - profile: OptimizationProfile, - ) -> List[Dict[str, int]]: + def get_valid_tactics(self, inputs: List[FakeTensor], + profile: OptimizationProfile, + **kwargs) -> List[Dict[str, int]]: # The simulated delay is not deterministic, so we need to return specific tactics here return [{ "block_size": block_size, From 9fb40d9632cfec48001d60ff2f03e57c2e8b1124 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Tue, 19 Aug 2025 17:09:31 +1200 Subject: [PATCH 10/13] Move finalize setup into base level getConfigs() function Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../include/moe_gemm_kernels.h | 9 +++--- .../cutlass_kernels/include/moe_kernels.h | 25 ++-------------- .../moe_gemm/moe_gemm_template_dispatch.h | 29 ++++++++++++++----- 3 files changed, 28 insertions(+), 35 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 77226417594..16b7838ed64 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -289,11 +289,10 @@ class MoeGemmRunner void moeGemm(GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs); - std::vector getConfigs() const; - static std::vector getConfigs(int sm); - static std::vector getTmaWarpSpecializedConfigs(int sm); - static std::vector getBlackwellConfigs(int sm); - static std::vector getHopperConfigs(int sm); + std::vector getConfigs(bool supports_finalize_fusion) const; + static std::vector getConfigs(int sm, bool supports_finalize_fusion); + static std::vector getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion); static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 7898f328a85..389591e7fea 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -600,36 +600,15 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface gemm2_config_ = std::move(gemm2_config); } - static auto addFinalizeFusionConfigs( - std::vector&& configs, bool use_fused_finalize) - { - if (!use_fused_finalize) - return configs; - - size_t const num_configs = configs.size(); - for (size_t i = 0; i < num_configs; ++i) - { - if (configs[i].is_tma_warp_specialized) - { - configs.push_back(configs[i]); - configs.back().epilogue_fusion_type - = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; - } - } - return configs; - } - std::vector getTactics(MoeGemmId gemm_id) override { - return Self::addFinalizeFusionConfigs( - moe_gemm_runner_.getConfigs(), gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); + return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); } static std::vector getTactics(int sm, MoeGemmId gemm_id) { using RunnerType = decltype(moe_gemm_runner_); - return Self::addFinalizeFusionConfigs( - RunnerType::getConfigs(sm), gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm)); + return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm)); } void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index a6238883499..0b009f60990 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -471,17 +471,18 @@ void dispatchMoeGemmToCutlass(GroupedGemmInput -std::vector -MoeGemmRunner::getConfigs() const +std::vector MoeGemmRunner::getConfigs( + bool supports_finalize_fusion) const { - return getConfigs(sm_); + return getConfigs(sm_, supports_finalize_fusion); } template std::vector MoeGemmRunner::getConfigs( - int sm) + int sm, bool supports_finalize_fusion) { - std::vector candidate_configs = getTmaWarpSpecializedConfigs(sm); + std::vector candidate_configs + = getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion); std::vector ampere_configs = getAmpereConfigs(sm); std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); return candidate_configs; @@ -517,7 +518,8 @@ MoeGemmRunner::getAmpereConfigs(int sm template std::vector -MoeGemmRunner::getTmaWarpSpecializedConfigs(int sm) +MoeGemmRunner::getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag @@ -554,6 +556,17 @@ MoeGemmRunner::getTmaWarpSpecializedCo std::vector tma_ws_configs = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + if (supports_finalize_fusion) + { + // Duplicate the configs and set the epilogue fusion type to FINALIZE + auto finalize_configs = tma_ws_configs; + std::transform(finalize_configs.begin(), finalize_configs.end(), std::back_inserter(tma_ws_configs), + [](auto& config) + { + config.epilogue_fusion_type = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + return config; + }); + } return tma_ws_configs; } @@ -815,7 +828,9 @@ size_t MoeGemmRunner::calcMaxWorkspace if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8 && !use_wfp4a16) { - auto configs = getTmaWarpSpecializedConfigs(sm_); + // Finalize fusion may not actually be supported by the kernel, + // if they are not we will catch the error and skip them + auto configs = getTmaWarpSpecializedConfigs(sm_, true); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr (use_wfp4afp4) { From 454f5c74a971b24c9d9f0f8cf590776d784a648e Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 20 Aug 2025 13:24:54 +1200 Subject: [PATCH 11/13] Fix linker errors for getting cutlass kernel configs Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../int8_gemm/int8_gemm_template.h | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h index c44caae0fa7..80048ab3f54 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h @@ -58,6 +58,9 @@ namespace kernels namespace cutlass_kernels { +namespace oss +{ + template void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, @@ -301,6 +304,7 @@ void dispatchGemmToCutlass(int8_t const* A, int8_t const* B, tk::QuantMode quant break; } } +} // namespace oss template CutlassInt8GemmRunner::CutlassInt8GemmRunner() @@ -326,18 +330,18 @@ void CutlassInt8GemmRunner::dispatchToArch(int8_t const* A, int8_t const* B, TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); if (mSm >= 72 && mSm < 75) { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); + oss::dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, + workspacePtr, workspaceBytes, gemmConfig, stream, occupancy); } else if (mSm >= 75 && mSm < 80) { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); + oss::dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, + workspacePtr, workspaceBytes, gemmConfig, stream, occupancy); } else if (mSm >= 80 && mSm <= 90 || mSm >= 120) { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); + oss::dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, + workspacePtr, workspaceBytes, gemmConfig, stream, occupancy); } else { From 71f927207af3a125a553c5a69e23c8ee5df9ef77 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 21 Aug 2025 11:01:12 +1200 Subject: [PATCH 12/13] Update internal cutlass artefacts Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../tensorrt_llm_internal_cutlass_kernels_static.tar.xz | 4 ++-- .../internal_cutlass_kernels/aarch64-linux-gnu/version.txt | 4 ++-- .../tensorrt_llm_internal_cutlass_kernels_static.tar.xz | 4 ++-- .../internal_cutlass_kernels/x86_64-linux-gnu/version.txt | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 08cd9b6f664..5ebd5f7ebe7 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf -size 67051604 +oid sha256:d6a3f6adef11003f794a6cec1235d0c622ead71b4e801a89866e91dfd91bb30c +size 67053244 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index 8b500f5c970..b93f46ea6d0 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a -commit 9c0a42825905952beaf9b35d5a35d58de1a123fa +317a25037093a6f3d156ffa58a68bce53071ef68dacdcb04cc0aaeea80b64e76 libtensorrt_llm_internal_cutlass_kernels_static.a +commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index f1a6b9dc88a..bd075284607 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d -size 66872936 +oid sha256:489fb557b78062efedd1514f2995fafb9216bb0e0068a550e86763efb9d5eee9 +size 66874608 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index 4af58b0800e..3c053c1a910 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a -commit 9c0a42825905952beaf9b35d5a35d58de1a123fa +5a31acd0fb1415196bff71fa4a8d1dded147e15ea10821cc46c85684c66986ee libtensorrt_llm_internal_cutlass_kernels_static.a +commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b From 2d7fb1b14dfcd4a1e4d36aa899ae0294d85d8f34 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 21 Aug 2025 11:28:44 +1200 Subject: [PATCH 13/13] Remove unneeded namespace changes Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../int8_gemm/int8_gemm_template.h | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h index 80048ab3f54..ef06abceee6 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h @@ -57,10 +57,6 @@ namespace kernels { namespace cutlass_kernels { - -namespace oss -{ - template void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, @@ -304,7 +300,6 @@ void dispatchGemmToCutlass(int8_t const* A, int8_t const* B, tk::QuantMode quant break; } } -} // namespace oss template CutlassInt8GemmRunner::CutlassInt8GemmRunner() @@ -330,18 +325,18 @@ void CutlassInt8GemmRunner::dispatchToArch(int8_t const* A, int8_t const* B, TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); if (mSm >= 72 && mSm < 75) { - oss::dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, - workspacePtr, workspaceBytes, gemmConfig, stream, occupancy); + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); } else if (mSm >= 75 && mSm < 80) { - oss::dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, - workspacePtr, workspaceBytes, gemmConfig, stream, occupancy); + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); } else if (mSm >= 80 && mSm <= 90 || mSm >= 120) { - oss::dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, - workspacePtr, workspaceBytes, gemmConfig, stream, occupancy); + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); } else {