diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index e321a4b07b3..59c9d2fffe4 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE() return forceDeterministic; } +bool getEnvMOEDisableFinalizeFusion() +{ + static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION"); + return moeDisableFinalizeFusion; +} + bool getEnvForceDeterministicAttention() { static bool const forceDeterministic diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h index b4921af40e9..f5c0d854ba4 100644 --- a/cpp/tensorrt_llm/common/envUtils.h +++ b/cpp/tensorrt_llm/common/envUtils.h @@ -86,6 +86,9 @@ bool getEnvForceDeterministic(); // Force deterministic behavior for MoE plugin. bool getEnvForceDeterministicMOE(); +// Disable finalize fusion in MoE plugin +bool getEnvMOEDisableFinalizeFusion(); + // Force deterministic behavior for attention plugin. bool getEnvForceDeterministicAttention(); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp deleted file mode 100644 index 09ae3e013ee..00000000000 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ /dev/null @@ -1,568 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/*! \file - \brief Functor performing elementwise operations used by epilogues. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" - -#include "cute/numeric/numeric_types.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" - -#include "cutlass_extensions/arch/copy_red_global.hpp" -#include "cutlass_extensions/util/gather_tensor.hpp" - -#include "cutlass/epilogue/collective/builders/sm90_builder.inl" -#include "cutlass/epilogue/collective/builders/sm90_common.inl" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class EpilogueMoeFusedFinalize -{ -public: - using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; - using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; - - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementIntermediate = typename ThreadEpilogueOp::ElementD; - - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - - static_assert(!is_same_v, "Stride C must be a pointer"); - static_assert(is_same_v, "Stride D must not be a pointer"); - - using CopyAtomR2S = Copy_Atom; - using CopyAtomS2R = Copy_Atom; - using CopyAtomR2G = Copy_Atom; - static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; - - using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - - struct SharedStorage - { - alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; - }; - - struct TensorMapStorage - { - }; - - struct Arguments - { - typename ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C{}; - StrideC dC{}; - ElementD* ptr_D{}; - StrideD dD{}; - ElementBias const* ptr_bias; - StrideBias dBias{}; - ElementScale const* ptr_scale; - StrideScale dScale{}; - int64_t const* group_offset{}; - int32_t const* scatter_index{}; - cutlass::FastDivmod num_rows_in_final_output; - }; - - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) - { - return args; - } - - template - static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) - { - return 0; - } - - template - static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, - void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) - { - return cutlass::Status::kSuccess; - } - - template - CUTLASS_HOST_DEVICE static bool can_implement( - [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) - { - bool implementable = true; - if (problem_shape.is_host_problem_shape_available()) - { - // Check alignment for all problem sizes - for (int i = 0; i < problem_shape.groups(); i++) - { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - 
implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); - } - } - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " - "reduction instruction.\n"); - } - return implementable; - } - - CUTLASS_HOST_DEVICE - EpilogueMoeFusedFinalize(Params const& params_) - : params(params_) - { - } - - CUTLASS_DEVICE - bool is_source_needed() - { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. - return params.ptr_C != nullptr - && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); - } - - template - CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) - { - using namespace cute; - using X = Underscore; - - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - auto synchronize = [&]() - { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - - // Batches are managed by using appropriate pointers to C and D matrices - int32_t const mock_L = 1; - int32_t const mock_l_coord = 0; - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op(params.thread, l_coord); - - SharedStorage& storage = *reinterpret_cast(smem_buf); - - Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); - Tensor sD = as_position_independent_swizzle_tensor(sD_); - - // Function to scatter output rows - auto& num_rows = params.num_rows_in_final_output; - auto read_scatter_map = tensorrt_llm::cutlass_extensions::IndexedGather( - make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); - auto get_scatter_idx = [&](auto i) - { - auto scatter = read_scatter_map(i); - int quot, rem; - num_rows(quot, rem, scatter); - return rem; - }; - - // Represent the full output tensor - ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; - auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) - Tensor mD_mnl = tensorrt_llm::cutlass_extensions::make_gather_tensor( - make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) - - // Use fake shape for bias, it doesn't matter - bool const is_bias_needed = params.ptr_bias != nullptr; - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); - Tensor mScale_mnl = make_tensor( - make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); - - Tensor gC_mnl - = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl - = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - Tensor gBias_mnl - = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gScale_mnl - = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) - Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) - - Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Get the smallest tiled copy we can use to retile the accumulators - TiledCopy tiled_copy_C_atom - = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); - TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); - - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) - - // Make a tiled copy vectorized along major direction of D - auto tiled_s2r = [&]() - { - if constexpr (cutlass::gemm::detail::is_k_major()) - { - constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride, _1>>{}, - Layout>>{}); - } - else if constexpr (cutlass::gemm::detail::is_mn_major()) - { - constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride<_1, Int>>{}, - Layout, _1>>{}); - } - else - { - static_assert(cute::is_void_v, "Unsupported D gmem layout."); - } - }(); - - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gScale = 
thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - - // Make an identity coordinate tensor for predicating our output MN tile - Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // epilogue subtile loop - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) - { - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) - { - int mma_m = (epi_m * epi_tile_m) / mma_tile_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; - Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); - - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) - { - tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); - } - - copy(tiled_r2s, tRS_rD, tRS_sD); - synchronize(); - - copy(tiled_s2r, tSR_sD, tSR_rD); - synchronize(); - - Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); - Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); - Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); - Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); - Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); - - if (epilogue_op.is_source_needed()) - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - } - } - } - -private: - Params params; -}; - -namespace detail -{ - -template 
-constexpr auto get_vectorized_atomic_add_op() -{ - using namespace cute; - - auto constexpr MaxVecSize = size(MaxVec{}); - - if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_F16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_F16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM70_RED_ADD_NOFTZ_F16x2{}; - } - else - { - return SM70_RED_ADD_NOFTZ_F16{}; - } - } - else if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM90_RED_ADD_NOFTZ_BF16x2{}; - } - else - { - return SM90_RED_ADD_NOFTZ_BF16{}; - } - } - else - { - // non-vectorized atomic add for all other types until supported - return TypedAtomicAdd{}; - } -} - -} // namespace detail - -template -struct EpilogueMoeFusedFinalizeBuilder -{ - - // assuming cooperative kernel schedule - using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); - using EpilogueTile = Shape<_128, EpiTileN>; - - // Output of linear combination is ElementCompute instead of ElementD - // since we will be doing more computate on it, no need to cast yet. - using ThreadEpilogueOp - = cutlass::epilogue::thread::LinearCombination; - - using SmemLayoutAtomD - = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); - using CopyAtomR2S - = decltype(detail::sm90_get_smem_store_op_for_accumulator()); - using CopyAtomS2R = DefaultCopy; - using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); - - template - struct TmaWarpSpecializedAdapterWithSmemStorageImpl : Base - { - // We need to override this one using declaration because otherwise we double up on the smem - using TensorMapStorage = typename EpilogueOp::TensorMapStorage; - - // using Base = detail::Sm90TmaWarpSpecializedAdapter; - - CUTLASS_HOST_DEVICE - TmaWarpSpecializedAdapterWithSmemStorageImpl( - typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) - : Base(params) - { - } - - CUTLASS_DEVICE auto load_init([[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx) - { - return cute::make_tuple(nullptr); - } - - CUTLASS_DEVICE auto store_init([[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] int32_t sm_count, - [[maybe_unused]] int32_t sm_idx, [[maybe_unused]] int32_t warp_group_idx) - { - return cute::make_tuple(nullptr); - } - - // Dummy methods to perform different parts of TMA/Tensormap modifications - - template - CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] ProblemShapeMNKL problem_shape, - [[maybe_unused]] int32_t next_batch, [[maybe_unused]] int32_t warp_group_idx) - { - } - - template - CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormaps, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t warp_group_idx) - { - } - - template - CUTLASS_DEVICE void tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) - { - } - }; - - template - using 
TmaWarpSpecializedAdapterWithSmemStorage = TmaWarpSpecializedAdapterWithSmemStorageImpl< - std::conditional_t= 100, detail::Sm100TmaWarpSpecializedAdapter, - detail::Sm90TmaWarpSpecializedAdapter>, - EpilogueOp>; - - using CollectiveOp = TmaWarpSpecializedAdapterWithSmemStorage< - EpilogueMoeFusedFinalize>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp new file mode 100644 index 00000000000..3571906a64f --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp @@ -0,0 +1,547 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +template < + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm90ScatterPtrArray { + + using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{}))))); + using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{})); + + using ElementIndex = int32_t; + // TODO: more generic treatment, or pass StrideIndex via template param? + using StrideIndex = conditional_t(), Stride<_0,_1,_0>, Stride<_1,_0,_0>>; + + struct SharedStorage {}; + + struct Arguments { + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + bool use_reduction = true; + }; + + struct Params { + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store + bool use_reduction = true; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return { + args.ptr_out, + args.dOut, + args.ptr_index, + cutlass::FastDivmod(args.index_modulo), + args.use_reduction + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class ArgsTuple + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple) + : args_tuple(std::move(args_tuple)) {} + + ArgsTuple args_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, 
tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rOut_frg = recast>(coalesce(tC_rOut)); // (EPI_V) + tC_rOut_frg(epi_v) = convert_input(frg_input); + + return tC_rOut_frg(epi_v); + } + + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + Tensor byte_buffer = recast(reduction_buffer); + static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v >= cosize(SmemLayout{}) * sizeof_bits_v, + "Not enough space in scratch smem buffer"); + + Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr(byte_buffer.data())), SmemLayout{})); + + auto thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut); + Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut); + + auto thread_r2g = tiled_r2g_red.get_slice(thread_idx); + Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n); + Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut); + Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers + + // sanity check for register reuse + CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G"); + + copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi); + sync_fn(); + copy(tRG_sOut_epi, tRG_rOut_epi); + + auto residue = residue_cD; // capturing structured bindings is a C++20 feature + Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n); + auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); }); + + if (use_reduction) { + copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi); + } + else { + copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi); + } + } + }; + + template + static constexpr auto get_reduction_op() + { + using namespace cute; + + // For now only support red.add + if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } + } + + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); }; + Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1) + Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1) + Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor tC_gOut = sm90_partition_for_epilogue(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tC_rOut = make_tensor(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + + // Vectorization must not exceed alignment and also the number of values per thread in the tile + int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy)); + int constexpr NumValTile = product(take<0,2>(shape(cD_epi))); + int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads); + + // Choose the largest available red.global op and an st.global op with matching vectorization + using CopyOpR2GRed = decltype(get_reduction_op()); + using CopyOpR2GStg = UniversalCopy::NumValSrc * sizeof_bits_v>>; + + auto make_tiled_r2g = [&](auto copy_op) + { + using CopyAtomR2G = Copy_Atom; + constexpr int VecSize = CopyAtomR2G::NumValSrc; + if constexpr (cutlass::gemm::detail::is_k_major()) { + constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) { + constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }; + + auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{}); + auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{}); + + // Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy, + // ensure they have matching layouts/tilers + using TiledR2GRed = decltype(tiled_r2g_red); + using TiledR2GStg = decltype(tiled_r2g_stg); + static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc"); + static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst"); + static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV"); + static_assert(typename TiledR2GRed::Tiler_MN{} == typename 
TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN"); + + auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx); + Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + + auto args_tuple = make_tuple( + cute::move(tC_rOut), + tiled_r2s, + tRG_gOut, + tRG_cD, + tiled_r2g_red, + tiled_r2g_stg, + params_ptr->use_reduction, + args.thread_idx, + args.residue_cD); + + return ConsumerStoreCallbacks(std::move(args_tuple)); + } +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBias + : ScaledAcc +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBiasPerColScaleScatter + : ScaledAccPerRowBias +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = alpha * acc + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPtrArray = + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray = + Sm90EVT, // scatter store + Sm90EVT, // scale * (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale + Sm90ScaledAccPerRowBiasPtrArray // alpha * acc + bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::ScaledAccPerRowBiasPerColScaleScatter, + CtaTileShapeMNK, + EpilogueTile, + 
SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + + using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias dBias{}; + + using StrideScale = Stride<_0,_1,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + bool use_reduction = true; + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +} // namespace cutlass::epilogue::fusion + +// clang-format on diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index f9355860bec..fe75687e368 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -133,11 +133,6 @@ enum class CutlassTileConfigSM100 CtaShape128x256x128B, CtaShape128x128x256B, CtaShape128x256x256B, - - // M=256 - CtaShape256x64x128B, - CtaShape256x128x128B, - CtaShape256x256x128B, }; enum class CutlassTileConfigSM120 diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp index e529ffc1faa..a83bf6a0830 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp 
@@ -19,7 +19,7 @@ #include "cute/tensor.hpp" #include "cute/util/print.hpp" -namespace tensorrt_llm::cutlass_extensions +namespace cutlass::util { /// Function object that applies an index to its argument @@ -81,7 +81,7 @@ struct CustomStride template CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { - return CustomStride(s.func_, safe_div(s.stride_, div)); + return CustomStride(s.func_, cute::safe_div(s.stride_, div)); } // Circumvent the requirement on make_layout that shape and stride are integral @@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); } -} // namespace tensorrt_llm::cutlass_extensions +} // namespace cutlass::util namespace cute { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 9e3bbaa32b7..837b916f366 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -377,72 +377,62 @@ std::vector get_candidate_configs_sm100(CutlassGemmConfig::Ca if (config & CutlassGemmConfig::GROUPED_GEMM) { std::vector candidate_configs; - if ((config & CutlassGemmConfig::FP4_ONLY) != 0) + if (config & CutlassGemmConfig::FP4_ONLY) { candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B, + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B, - MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, - MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); - candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); return candidate_configs; } - for (int cluster_m = 1; cluster_m <= 2; cluster_m++) + std::vector> tile_configs{ + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1}, + 
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1}, + {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1}, + {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1}, + {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1}, + }; + + if (config & CutlassGemmConfig::FP8_ONLY) { - bool Is2SM = cluster_m == 2; - for (int cluster_n = 1; cluster_n <= 2; cluster_n++) - { - std::vector base = {// M=128 - CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B}; - - if (Is2SM) - { - if (cluster_n == 1) - { - base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); - base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); - } - - std::vector twosm = {// M=256 - CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; - std::copy(twosm.begin(), twosm.end(), std::back_inserter(base)); - } - else - { - if (cluster_n == 1) - { - base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); - if ((config & CutlassGemmConfig::FP8_ONLY) != 0) - { - base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); - } - } - - std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, - CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B, - CutlassTileConfigSM100::CtaShape128x64x128B}; - std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); - } + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1}); + // TODO: re-enable when handled by the MoE GEMM dispatch + // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 }); + } - constexpr std::array cluster_shapes - = {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, - std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}; - auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; - for (auto tile : base) - { - CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, 
cluster}; - candidate_configs.push_back(config); - } - } + for (auto [tile, cluster] : tile_configs) + { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster}; + candidate_configs.push_back(config); } return candidate_configs; } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 1237884d13c..3c814851c91 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -37,11 +37,6 @@ namespace tensorrt_llm::kernels::cutlass_kernels { -template -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); -} template struct GroupedGemmInput @@ -72,8 +67,6 @@ struct GroupedGemmInput struct TmaWarpSpecializedGroupedGemmInput { - template - using TransposeStride = decltype(transpose_stride(T{})); template using TransposeLayoutTag = std::conditional_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; @@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand + using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput using StrideB = std::remove_pointer_t>; // Use A because they will be swapped using StrideC = std::remove_pointer_t>; + using StrideD = std::remove_pointer_t>; #ifdef ENABLE_FP8 template @@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput StrideC* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t; void* ptr_final_output = nullptr; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion @@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; - static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, 
FpXBlockScalingType scaling_type); @@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput return stride_a != nullptr && ptr_a != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction); std::string toString() const; }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index ca256ae0d6b..7d592bed0e4 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -495,7 +495,8 @@ class CutlassMoeFCRunnerInterface void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) = 0; virtual std::pair @@ -512,13 +513,13 @@ class CutlassMoeFCRunnerInterface virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. // Avoid making several duplicates of this class. 
-template (bias1), reinterpret_cast(bias2), reinterpret_cast(gemm1_output), - reinterpret_cast(gemm2_output), stream); + reinterpret_cast(gemm2_output), router_scales, permuted_row_to_unpermuted_row, + stream); } std::pair @@ -760,7 +763,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream); static std::pair computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, @@ -790,8 +794,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface bool mayHaveFinalizeFused() const { - return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 - && !use_deterministic_hopper_reduce_ && !use_w4_groupwise; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_ + && !use_w4_groupwise; } // TODO: This should eventually take the quant params to give more flexibility diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index d5f0b198fd8..e92186a3f5c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -26,21 +26,14 @@ #include "cute/tensor.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/fusion/operations.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -189,17 +182,19 @@ using SafeBF16 = void; TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, \ cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) \ { \ + using ArchTag = cutlass::arch::ArchTag_; \ constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ + constexpr static bool IsMXFPX = MXFPX_; \ + constexpr bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \ + constexpr bool IsSM120 = ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \ + constexpr bool Is2SM = IsBlackwell && (CGA_M_ % 2 == 0); \ /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ - using ArchTag = cutlass::arch::ArchTag_; \ using T = DataType_; \ using WeightType = WeightType_; \ using OutputType = OutputType_; \ using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ - using TileShape = cute::Shape, cute::Int, cute::Int>; \ + using MmaTileShape = cute::Shape, cute::Int, cute::Int>; \ using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ - constexpr static bool IsMXFPX = MXFPX_; \ - \ if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && ArchTag::kMinComputeCapability >= 90 \ && ArchTag::kMinComputeCapability < 100) \ { \ @@ -217,18 +212,15 @@ using SafeBF16 = void; TLLM_THROW( \ "Please recompile with support for blackwell by passing 120-real as an arch to build_wheel.py."); \ } \ - else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) \ + else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) \ { \ using namespace cute; \ /* Helper class for defining all the cutlass types \ // template \ + // typename MmaTileShape, typename ClusterShape, bool BIAS, EpilogueFusion FUSION> \ // struct TmaWarpSpecializedGroupedGemmInfo \ { */ \ - using Arch = ArchTag; \ - constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \ - constexpr static bool IsSM120 = Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \ constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value \ && cutlass::platform::is_same::value; \ constexpr static bool IsFP4 = cutlass::platform::is_same::value; \ @@ -308,8 +300,8 @@ using SafeBF16 = void; // units of elements (up to 16 bytes)*/ \ \ /* D matrix configuration */ \ - using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \ - using StrideD = 
TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \ + using LayoutD = TmaWarpSpecializedGroupedGemmInput::LayoutD; \ + using StrideD = TmaWarpSpecializedGroupedGemmInput::StrideD; \ constexpr static int AlignmentD \ = 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment of D matrix \ // in units of elements (up to 16 bytes) */ \ @@ -327,30 +319,24 @@ using SafeBF16 = void; // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ // cutlass::epilogue::?????????????????? /// <<<<<< what supports activations \ // >;*/ \ - using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \ + using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ \ - constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \ using EpilogueScheduleSM100 = std::conditional_t; \ using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \ - using EpilogueScheduleBW = std ::conditional_t; \ + using EpilogueScheduleBW = std::conditional_t; \ using EpilogueSchedule = std::conditional_t; \ \ - using EpilogueTileShapeSm90 = TileShape; \ - using AtomClusterDiv = std::conditional_t; \ - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape{})); \ - using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using EpilogueTileShape = std::conditional_t; \ using EpilogueElementC = std::conditional_t; \ using EpilogueTensorOp = std::conditional_t; \ - using EpilogueSubTile \ - = std::conditional_t, cutlass::epilogue::collective::EpilogueTileAuto>; \ + using EpilogueSubTile = std::conditional_t, cutlass::epilogue::collective::EpilogueTileAuto>; \ /* Epilogue For Default Finalize */ \ using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp; \ \ /* Epilogue For Fused Finalize */ \ - using CollectiveEpilogueFinalize = \ - typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< /**/ \ - Arch, EpilogueTileShape, /**/ \ - ElementCSafe, StrideC*, /**/ \ - ElementFinalOutput, \ - TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, /**/ \ - ElementAccumulator, /**/ \ - ElementAccumulator, /**/ \ - ElementBias, TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, /**/ \ - ElementRouterScales, \ - TmaWarpSpecializedGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales /**/ \ - >::CollectiveOp; \ + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder /**/ \ + >::CollectiveOp; \ \ using CollectiveEpilogue = std::conditional_t; \ @@ -405,16 +390,12 @@ using SafeBF16 = void; using MainloopElementA = std::conditional_t; \ using MainloopElementB = std::conditional_t; \ \ - using MainloopTileShapeSm90 = TileShape; \ - using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using MainloopTileShape = std::conditional_t; \ - \ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder::CollectiveOp; \ \ using GemmKernel = cutlass::gemm::kernel::GemmUniversal; \ /*}; \ - \ \ + // \ // using namespace cute; \ // using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;; \ // \ // using ElementAccumulator = typename GemmInfo::ElementAccumulator; \ @@ -478,7 +459,7 @@ using SafeBF16 = void; TLLM_CHECK(tma_ws_input.ptr_a); \ TLLM_CHECK(tma_ws_input.ptr_b); \ \ - auto make_mainloop_params = [&]() -> MainloopArguments \ + MainloopArguments const mainloop_args = [&] \ { \ if constexpr (IsBlockScaled) \ { \ 
@@ -498,67 +479,46 @@ using SafeBF16 = void; reinterpret_cast(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ reinterpret_cast(tma_ws_input.ptr_a), tma_ws_input.stride_a); \ } \ - }; \ - \ - auto const mainloop_params = make_mainloop_params(); \ - \ - using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ - using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ - auto make_epilogue_scalars = [&]() \ + }(); \ + using FusionArguments = typename CollectiveEpilogue::FusionCallbacks::Arguments; \ + FusionArguments fusion_args = [&] \ { \ - if constexpr (IsBlackwell) \ - { \ - return construct_if_true(ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, nullptr, \ - tma_ws_input.alpha_scale_ptr_array, nullptr, \ - cute::Shape<_0, _0, int64_t>{ \ - cute::_0{}, cute::_0{}, (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ - cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ - } \ - else if (tma_ws_input.alpha_scale_ptr_array) \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) \ { \ - return construct_if_true(tma_ws_input.alpha_scale_ptr_array); \ + auto epi_params = tma_ws_input.fused_finalize_epilogue; \ + return construct_if_true( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast(epi_params.ptr_bias), \ + Stride<_1, _0, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \ + reinterpret_cast(epi_params.ptr_final_output), \ + epi_params.stride_final_output, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.use_reduction); \ } \ else \ { \ - return construct_if_true(ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? 
ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ + return construct_if_true( \ + ElementAccumulator(1), ElementAccumulator(0), nullptr, nullptr, \ + tma_ws_input.alpha_scale_ptr_array, nullptr, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, Stride<_0, _0, int64_t>{}); \ } \ - }; \ - auto epilogue_scalars = make_epilogue_scalars(); \ - /* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \ - auto make_epi_args = [&]() \ - { \ - static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ - "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ + }(); \ \ - if constexpr (FUSION == EpilogueFusion::NONE) \ + using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ + EpilogueArguments epilogue_args = [&] \ + { \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) \ { \ - auto epi_params = tma_ws_input.default_epilogue; \ - return construct_if_true(epilogue_scalars, \ - nullptr, tma_ws_input.stride_c, reinterpret_cast(epi_params.ptr_d), \ - epi_params.stride_d); \ + return construct_if_true( \ + fusion_args, nullptr, nullptr, nullptr, nullptr); \ } \ - else if constexpr (FUSION == EpilogueFusion::FINALIZE) \ + else \ { \ - /* Parameters for fused finalize */ \ - auto epi_params = tma_ws_input.fused_finalize_epilogue; \ - return construct_if_true( \ - epilogue_scalars, /* Parameters to underlying epilogue */ \ - nullptr, tma_ws_input.stride_c, /* C params */ \ - reinterpret_cast(epi_params.ptr_final_output), \ - epi_params.stride_final_output, /* D (output) params */ \ - reinterpret_cast(epi_params.ptr_bias), \ - epi_params.stride_bias, /* Bias params */ \ - epi_params.ptr_router_scales, epi_params.stride_router_scales, /* Router scales */ \ - epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token in the \ - router scales */ \ - epi_params.ptr_source_token_index, /* Index of the source token to sum into */ \ - epi_params.num_rows_in_final_output /* Number of tokens in the output buffer */ \ - ); \ + return construct_if_true(fusion_args, \ + nullptr, nullptr, reinterpret_cast(tma_ws_input.ptr_d), tma_ws_input.stride_d); \ } \ - }; \ - EpilogueArguments const epilogue_params = make_epi_args(); \ + }(); \ /* EpilogueArguments const epilogue_params = make_epi_args( \ // tma_ws_input, epilogue_scalars \ @@ -568,7 +528,7 @@ using SafeBF16 = void; 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ \ const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - tma_ws_input.shape_info, mainloop_params, epilogue_params, hw_info, scheduler_args}; \ + tma_ws_input.shape_info, mainloop_args, epilogue_args, hw_info, scheduler_args}; \ \ size_t calculated_ws_size = gemm.get_workspace_size(args); \ TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index 651b7f14060..719824c4c6c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -197,8 +197,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.int4_groupwise_params.ptr_s_a), hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, 
reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast(hopper_inputs.ptr_d), hopper_inputs.stride_d}, hw_info}; *workspace_size = gemm.get_workspace_size(args); return; @@ -211,8 +210,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.int4_groupwise_params.ptr_s_a), hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast(hopper_inputs.ptr_d), hopper_inputs.stride_d}, hw_info}; if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index d9df31513f3..40496a6a0eb 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -138,11 +138,11 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn } } -template +template constexpr bool are_tile_shapes_supported_sm100() { using namespace cute; - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); + // This is the epilogue shape. The MMA shape will be twice this for 2SM constexpr auto TileM = size<0>(CtaShape{}); constexpr auto TileN = size<1>(CtaShape{}); @@ -353,6 +353,7 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG { switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(100, 64, 32, 128) SHAPE_CASE(100, 64, 64, 128) SHAPE_CASE(100, 64, 128, 128) SHAPE_CASE(100, 64, 256, 128) @@ -363,13 +364,8 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG SHAPE_CASE(100, 128, 128, 128) SHAPE_CASE(100, 128, 256, 128) - SHAPE_CASE(100, 256, 64, 128) - SHAPE_CASE(100, 256, 128, 128) - SHAPE_CASE(100, 256, 256, 128) - // SHAPE_CASE(100, 128, 128, 64) // SHAPE_CASE(100, 128, 256, 64) - // SHAPE_CASE(100, 256, 256, 64) DEFAULT_CASE(100) } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu index 485c19496f3..b49dfec9992 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu @@ -27,14 +27,14 @@ namespace tensorrt_llm::kernels::cutlass_kernels { -std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( +std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( int num_experts, FpXBlockScalingType scaling_type) { size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; size_t stride_a_size = sizeof(StrideA) * num_experts; size_t stride_b_size = sizeof(StrideB) * num_experts; size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + size_t stride_d_size = sizeof(StrideD) * num_experts; size_t ptr_buf_size = sizeof(void*) * num_experts; size_t scale_buf_size = sizeof(float*) * num_experts; @@ -53,9 +53,12 @@ std::array 
TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts; size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts; + size_t ptr_token_map_size = sizeof(int**) * num_experts; + return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size, sf_a_size, sf_b_size, stride_sf_a_size, - stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size}; + stride_sf_b_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, int4_groupwise_stride_sf_a_size, + ptr_buf_size, scale_buf_size, ptr_token_map_size}; } size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, FpXBlockScalingType scaling_type) @@ -68,7 +71,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i size_t gemm_workspace_size, FpXBlockScalingType scaling_type) { auto buffers = workspaceBuffers(num_experts, scaling_type); - std::array pointers{}; + std::array pointers{}; TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); for (int i = 0; i < buffers.size(); i++) { @@ -82,12 +85,12 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i stride_a = reinterpret_cast(pointers[1]); stride_b = reinterpret_cast(pointers[2]); stride_c = reinterpret_cast(pointers[3]); - default_epilogue.stride_d = reinterpret_cast(pointers[4]); + stride_d = reinterpret_cast(pointers[4]); ptr_a = reinterpret_cast(pointers[5]); ptr_b = reinterpret_cast(pointers[6]); ptr_c = reinterpret_cast(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + ptr_d = reinterpret_cast(pointers[8]); alpha_scale_ptr_array = reinterpret_cast(pointers[9]); @@ -103,28 +106,24 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i int4_groupwise_params.ptr_s_a = reinterpret_cast(pointers[15]); int4_groupwise_params.stride_s_a = reinterpret_cast(pointers[16]); + fused_finalize_epilogue.ptr_bias = reinterpret_cast(pointers[17]); + fused_finalize_epilogue.ptr_router_scales = reinterpret_cast(pointers[18]); + fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast(pointers[19]); + this->gemm_workspace = reinterpret_cast(gemm_workspace); this->gemm_workspace_size = gemm_workspace_size; } -void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens) +void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams( + void* final_output, int hidden_size, int num_output_tokens, bool use_reduction) { fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - - fused_finalize_epilogue.stride_final_output - = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias - = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - 
fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride( + FusedFinalizeEpilogue::StrideFinalOutput{}, cute::make_shape(hidden_size, num_output_tokens, 1)); fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; + fused_finalize_epilogue.use_reduction = use_reduction; } std::string TmaWarpSpecializedGroupedGemmInput::toString() const @@ -143,16 +142,13 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const ss << "Final Output: " << (PrintType) fused_finalize_epilogue.ptr_final_output; ss << " with Stride: " << fused_finalize_epilogue.stride_final_output; ss << ",\nBias: " << (PrintType) fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " << (PrintType) fused_finalize_epilogue.ptr_expert_first_token_offset; ss << ", Source Map: " << (PrintType) fused_finalize_epilogue.ptr_source_token_index; } else { - ss << "Ptr D: " << (PrintType) default_epilogue.ptr_d; - ss << " with Stride: " << (PrintType) default_epilogue.stride_d; + ss << "Ptr D: " << (PrintType) ptr_d; + ss << " with Stride: " << (PrintType) stride_d; } ss << '\n'; ss << "Alpha scale ptr: " << (PrintType) alpha_scale_ptr_array << "\n"; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index ae4c25f379f..730840717c2 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -1165,8 +1165,8 @@ __device__ void computeTmaWarpSpecializedInputStrides( } if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1)); + layout_info.stride_d[out_idx] = cutlass::make_cute_packed_stride( + TmaWarpSpecializedGroupedGemmInput::StrideD{}, cute::make_shape(gemm_n, gemm_m, 1)); } if (layout_info.int4_groupwise_params.enabled) { @@ -1185,7 +1185,8 @@ template __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m, int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in, WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale, - ScaleBiasType const* bias, OutputType* output, int64_t const out_idx) + ScaleBiasType const* bias, OutputType* output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, int64_t const out_idx) { // The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); @@ -1196,7 +1197,18 @@ __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGrouped if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens_before_expert` tokens - layout_info.default_epilogue.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + } + if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) + { + + layout_info.fused_finalize_epilogue.ptr_source_token_index[expert] + = permuted_row_to_unpermuted_row + num_tokens_before_expert; + layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = router_scales + num_tokens_before_expert; + if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr) + { + layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert; + } } if (layout_info.int4_groupwise_params.enabled) { @@ -1219,7 +1231,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir WeightType const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output) + ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, OutputType* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row) { // First, compute the global tid. We only need 1 thread per expert. 
int const expert = blockIdx.x * blockDim.x + threadIdx.x; @@ -1297,12 +1310,12 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir gemm1_in, weights1, reinterpret_cast( quant_params.groupwise.fc1.weight_scales), - bias1, gemm1_output, expert); + bias1, gemm1_output, nullptr, nullptr, expert); computeTmaWarpSpecializedInputPointers(layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert, gemm2_in, weights2, reinterpret_cast( quant_params.groupwise.fc2.weight_scales), - bias2, gemm2_output, expert); + bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert); } template @@ -1420,12 +1433,12 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); + layout_info1.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); + layout_info2.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); } } else @@ -1435,10 +1448,10 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali layout_info1.ptr_b[expert] = nullptr; layout_info2.ptr_b[expert] = nullptr; - layout_info1.default_epilogue.ptr_d[expert] = nullptr; + layout_info1.ptr_d[expert] = nullptr; if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; + layout_info2.ptr_d[expert] = nullptr; } } } @@ -2015,8 +2028,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro #define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \ template void finalizeMoeRoutingKernelLauncher( \ GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, ScaleBiasT const* bias, \ - float const* final_scales, int const* expanded_source_row_to_expanded_dest_row, \ - int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \ + float const* final_scales, int const* unpermuted_row_to_permuted_row, \ + int const* permuted_row_to_unpermuted_row, int const* expert_for_source_row, \ int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \ int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, \ bool const enable_alltoall, cudaStream_t stream); @@ -3295,9 +3308,8 @@ void CutlassMoeFCRunner:: float const* fp8_dequant2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream) + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream) { // Always nullptr layout_info1.ptr_c = nullptr; @@ -3823,6 +3836,12 @@ CutlassMoeFCRunner:: layout_info2.ptr_c = nullptr; layout_info2.stride_c = nullptr; + 
layout_info1.fused_finalize_epilogue.ptr_bias = nullptr; + if (!bias2) + { + layout_info2.fused_finalize_epilogue.ptr_bias = nullptr; + } + auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale : use_fp8 ? fp8_dequant1 @@ -3863,7 +3882,7 @@ CutlassMoeFCRunner:: cudaLaunchKernelEx(&config, kernel_instance, expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, gemm1_in, gemm2_in, weights1, weights2, alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, bias1, bias2, - gemm1_output, gemm2_output); + gemm1_output, gemm2_output, router_scales, permuted_row_to_unpermuted_row); return std::make_pair(layout_info1, layout_info2); } @@ -3986,15 +4005,15 @@ CutlassMoeFCRunner:: gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; bool apply_bias = parallelism_config.tp_rank == 0; - bool using_hopper_fused_finalize - = !use_deterministic_hopper_reduce_ && gemm2_config_->sm_version == 90 && !use_w4_groupwise && !use_lora; - if (using_hopper_fused_finalize) + auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr; + bool using_fused_finalize + = use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora; + if (using_fused_finalize) { assert(min_latency_mode == false); + bool use_reduction = expanded_num_rows > num_rows; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams(final_output, permuted_token_final_scales_, - expert_first_token_offset_, permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr, - hidden_size, num_rows); + gemm2_tma_ws_input.setFinalizeFusionParams(final_output, hidden_size, num_rows, use_reduction); } // fp8_mxfp4 memsets the scaling factors to 1.0f @@ -4028,9 +4047,10 @@ CutlassMoeFCRunner:: gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, num_experts_per_node, reinterpret_cast(gemm1_input), reinterpret_cast(gemm2_input), fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, - fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_expert_biases, + fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias, reinterpret_cast(gemm1_output), - reinterpret_cast(fc2_result_), stream); + reinterpret_cast(fc2_result_), permuted_token_final_scales_, + permuted_row_to_unpermuted_row_, stream); } } @@ -4591,20 +4611,17 @@ void GemmProfilerBackend::prepareTmaWsInputs( gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; - bool apply_bias = true; bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && mWType == nvinfer1::DataType::kUINT8); bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; bool using_fused_finalize - = !mInterface->use_deterministic_hopper_reduce_ && mSM == 90 && !mMinLatencyMode && !use_w4_groupwise; + = mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise; if (using_fused_finalize) { assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; 
- gemm2_tma_ws_input.setFinalizeFusionParams(output, token_topk_unpermuted_scales, - expert_first_token_offset, permuted_row_to_unpermuted_row, apply_bias ? bias : nullptr, - mExpertHiddenSize, num_tokens); + gemm2_tma_ws_input.setFinalizeFusionParams(output, mExpertHiddenSize, num_tokens, mK > 1); } auto fc1_output_size = isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize; @@ -4625,7 +4642,7 @@ void GemmProfilerBackend::prepareTmaWsInputs( fc1_output_size, mExpertHiddenSize, mExpertHiddenSize, mExpertInterSize, mNumExpertsPerNode, input, input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate, - stream); + token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, stream); } sync_check_cuda_error(stream); } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index b6306d3c1de..273eb7d4eb4 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -35,8 +35,7 @@ constexpr bool isValidSM120MOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice return cutlass::platform::is_same::value && cutlass::platform::is_same::value - && cutlass::platform::is_same::value - && Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + && cutlass::platform::is_same::value; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -51,8 +50,7 @@ constexpr bool isValidBlackwellMOESpecialisation() return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && cutlass::platform::is_same::value)) - && cutlass::platform::is_same::value - && Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + && cutlass::platform::is_same::value; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py index 07d6ca2df1a..e3fb4461af1 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py @@ -212,8 +212,7 @@ def instantiate_operation_tma_warp_specialized(operation): {kernel_sched}, {epi_sched}> ( const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, {out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* -); -""" +);""" elif operation.gemm_kind == GemmKind.Grouped: if operation.act_type != operation.weight_type and ( operation.act_type != DataType.e4m3 @@ -261,11 +260,9 @@ def instantiate_operation_tma_warp_specialized(operation): # (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*); # """ instantiation = f""" -#if {guard_act} && {guard_weight}\n - INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, - {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, 
{"true" if operation.is_mx_fpx else "false"}, false);\n -#endif -""" +#if {guard_act} && {guard_weight} + INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false); +#endif""" return instantiation @@ -276,8 +273,7 @@ def instantiate_operation_sm80(operation): instantiation = f""" template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}> - ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); - """ + ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);""" return instantiation @@ -340,16 +336,13 @@ def write_file(launcher_inl_files, operations, output_file): f.write(content) -from operator import mul, truediv - - def elementwise(x, y, f): return tuple(f(a, b) for (a, b) in zip(x, y)) def is_gemm_op_valid_sm100(op): # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future - tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv) + tile_m, tile_n, _ = op.cta_shape cga_m, cga_n, _ = op.cga_shape # Default shapes @@ -366,13 +359,11 @@ def is_gemm_op_valid_sm100(op): return False # Shapes for fp8 small N shapes - if (op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8) - and (cga_m == 1 and cga_n == 1)): - # todo: double check why this is disable in CUTLASS backend. @yuhan - if tile_m == 128 and tile_n == 8: - return False - else: - return True + if (op.act_type == DataType.e4m3) and (tile_n == 16 + or tile_n == 8) and (cga_m == 1 + and cga_n == 1): + # todo: double check why tile_n = 8 is disabled in CUTLASS backend. 
@yuhan + return tile_m != 128 or tile_n % 16 == 0 # Default alignment requirements if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256: @@ -617,8 +608,6 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype): cta_shape_k = max_k_bits // GetDataTypeBits(dtype) if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8): cta_shape_k = 256 - if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16): - cta_shape_k = 128 return cta_shape_mn + (cta_shape_k, ) @@ -638,7 +627,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize ] cga_shapes = [[1, 1, 1]] @@ -648,7 +637,6 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): operations = list() for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args: - cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative @@ -661,8 +649,8 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype, - quant_op, epi_tag, cga_tile_shape_mnk, warp_shape, stages, - cga_shape, mainloop_schedule, epi_schedule, epi_fusion) + quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape, + mainloop_schedule, epi_schedule, epi_fusion) operations.append(moe_gemm_operation) return operations @@ -692,7 +680,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize ] cga_shapes = list(product([1, 2], [1, 2], [1])) @@ -708,7 +696,6 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): weight_type = dtype cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype) - cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative @@ -729,7 +716,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): otype, quant_op, epi_tag, - cga_tile_shape_mnk, + cta_shape_mnk, warp_shape, stages, cga_shape, diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 88630d62436..08cd9b6f664 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b25eb3a8bc1fae83eb43f9e0cf8fd93bb00f412d6cbd1bf7e2214e878bec3b4a -size 64735372 +oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf +size 67051604 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index ca9cab4b456..8b500f5c970 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ 
-16bae34717995b98ee8cff17bc8ec080c0e1b1aca02e5949be171eb8d40eff39 libtensorrt_llm_internal_cutlass_kernels_static.a -commit 995030f9b86258f3db876df6b1dbc46a7c5dae50 +568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a +commit 9c0a42825905952beaf9b35d5a35d58de1a123fa diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h index c045485f16d..3a72417a216 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h @@ -39,11 +39,6 @@ namespace tensorrt_llm { -template -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); -} // Note update moe.py to match enum class ActivationType @@ -87,8 +82,6 @@ struct GroupedGemmInput struct TmaWarpSpecializedGroupedGemmInput { - template - using TransposeStride = decltype(transpose_stride(T{})); template using TransposeLayoutTag = std::conditional_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; @@ -101,6 +94,7 @@ struct TmaWarpSpecializedGroupedGemmInput using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand + using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -122,6 +116,7 @@ struct TmaWarpSpecializedGroupedGemmInput using StrideB = std::remove_pointer_t>; // Use A because they will be swapped using StrideC = std::remove_pointer_t>; + using StrideD = std::remove_pointer_t>; #ifdef ENABLE_FP8 template @@ -148,37 +143,26 @@ struct TmaWarpSpecializedGroupedGemmInput StrideC* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t; void* ptr_final_output = nullptr; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion @@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; - static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array workspaceBuffers(int 
num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); @@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput return stride_a != nullptr && ptr_a != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction); std::string toString() const; }; diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h index e16bc34a2e7..1bda2247ce6 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h @@ -426,7 +426,7 @@ class CutlassMoeFCRunnerInterface ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, bool use_lora, + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; @@ -450,8 +450,8 @@ class CutlassMoeFCRunnerInterface void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, @@ -468,7 +468,8 @@ class CutlassMoeFCRunnerInterface void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) = 0; virtual std::pair @@ -485,13 +486,13 @@ class CutlassMoeFCRunnerInterface virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row 
major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. // Avoid making several duplicates of this class. -template (final_output), expert_first_token_offset, tma_ws_input_template, static_cast(fc2_expert_weights), static_cast(fc2_expert_biases), static_cast(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params, - token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row, - expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, - expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, - use_lora, fc2_lora, stream, parallelism_config, config, min_latency_mode, num_active_experts_per, - active_expert_global_ids, start_expert); + token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row, + permuted_row_to_unpermuted_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows, + hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, + stream, parallelism_config, config, min_latency_mode, num_active_experts_per, active_expert_global_ids, + start_expert); } virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override @@ -673,7 +674,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, - void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override + void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, @@ -682,7 +684,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast(bias1), reinterpret_cast(bias2), reinterpret_cast(gemm1_output), - reinterpret_cast(gemm2_output), stream); + reinterpret_cast(gemm2_output), router_scales, permuted_row_to_unpermuted_row, + stream); } std::pair @@ -724,7 +727,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row, + cudaStream_t stream); static std::pair computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, @@ -754,8 +758,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface bool mayHaveFinalizeFused() const { - return 
moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 - && !use_deterministic_hopper_reduce_ && !use_w4afp8; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_ + && !use_w4afp8; } // TODO: This should eventually take the quant params to give more flexibility @@ -791,7 +795,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, - float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row, + float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row, int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config, diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 6a96527252c..f1a6b9dc88a 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3767deac592a204493b09f6798d50269c90d4571971b1746a5e6d0009a6d6d65 -size 64229720 +oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d +size 66872936 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index 74e01ceecb9..4af58b0800e 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -f68113dae0236968594276bf4f8b0a6f9161d3fbbac6fcb9ea1a438d16055490 libtensorrt_llm_internal_cutlass_kernels_static.a -commit 995030f9b86258f3db876df6b1dbc46a7c5dae50 +813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a +commit 9c0a42825905952beaf9b35d5a35d58de1a123fa diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index 3d1d73a0d8c..189e23b8acb 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -334,12 +334,13 @@ void MixtureOfExpertsPlugin::init() static_cast(mType), static_cast(mWeightType), static_cast(mOutputType)); } - mMOERunner->use_deterministic_hopper_reduce_ = mExpertsPerToken > 2 && mUseDeterministicKernels; + mMOERunner->use_fused_finalize_ + = (mExpertsPerToken < 3 || !mUseDeterministicKernels) && !getEnvMOEDisableFinalizeFusion(); mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize, - mGroupSize, 
mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_}; + mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_}; mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize, - mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_}; + mGroupSize, mActivationType, mType, mWeightType, mQuantMode, !mMOERunner->use_fused_finalize_}; mGemmProfiler->setMaxProfileM(16384 * mNumExperts / mExpertsPerToken); if (hasLora()) diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 204be5b3766..299e302ec53 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -95,7 +95,8 @@ class FusedMoeRunner : public torch::CustomClassHolder }; FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype, - bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling) + bool use_deepseek_fp8_block_scale, bool use_w4_group_scaling, bool use_mxfp8_act_scaling, + bool use_fused_finalize) { mActivationDtype = activation_dtype; mWeightDtype = weight_dtype; @@ -103,6 +104,7 @@ class FusedMoeRunner : public torch::CustomClassHolder mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale; mUseW4GroupScaling = use_w4_group_scaling; mUseMxfp8ActScaling = use_mxfp8_act_scaling; + mUseFusedFinalize = use_fused_finalize; mInnerDimMultiplier = 1; // keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -213,6 +215,8 @@ class FusedMoeRunner : public torch::CustomClassHolder << ", Output: " << torch::toString(mOutputDtype)); } + mKernelRunner->use_fused_finalize_ = mUseFusedFinalize; + mProfiler = std::make_shared(); mAllProfiles = mKernelRunner->getTactics(); } @@ -674,6 +678,7 @@ class FusedMoeRunner : public torch::CustomClassHolder bool mUseDeepSeekFP8BlockScaling = false; bool mUseW4GroupScaling = false; bool mUseMxfp8ActScaling = false; + bool mUseFusedFinalize = true; using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; std::vector mAllProfiles; @@ -1045,7 +1050,7 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_LIBRARY(trtllm, m) { m.class_("FusedMoeRunner") - .def(torch::init()) + .def(torch::init()) .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile) .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum) .def("run_moe", &torch_ext::FusedMoeRunner::runMoe) diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index b615fca03c9..6f2ce0f93e6 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
@@ -355,7 +371,7 @@ protected:
     float mSparseMixerEpsilon = 0.2f;
 
     // Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths
-    bool mUseDeterminsiticHopperReduce = true;
+    bool mUseDeterministicHopperReduce = true;
 
     // Disable this for long running tests to speed up runtime
     bool mIsLongTest = false;
@@ -440,7 +456,7 @@ protected:
     {
         managed_buffers.clear();
 
-        mMoERunner.use_deterministic_hopper_reduce_ = k > 2 && mUseDeterminsiticHopperReduce;
+        mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce;
 
         mHiddenSize = hidden_size;
         mInterSize = hidden_size * mInterSizeFraction;
@@ -1614,7 +1630,7 @@ void MixtureOfExpertsTest::BasicPermuteTest(
 
         runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
         bool should_be_deterministic
-            = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
+            = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
         if (should_be_deterministic && !mIsLongTest)
         {
             auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1733,7 +1749,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias)
 
 TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic)
 {
-    this->mUseDeterminsiticHopperReduce = false;
+    this->mUseDeterministicHopperReduce = false;
     // Just test case 3, cases 1&2 always use the fused paths
     this->BasicPermuteTest(3);
 }
@@ -1881,7 +1897,7 @@ void MixtureOfExpertsTest::ParallelismTest(
                 runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
                     MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
                 bool should_be_deterministic
-                    = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
+                    = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
                 if (should_be_deterministic && !mIsLongTest)
                 {
                     auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@@ -1897,7 +1913,7 @@ void MixtureOfExpertsTest::ParallelismTest(
             {
                 runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
                 bool should_be_deterministic
-                    = mUseDeterminsiticHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
+                    = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
                 if (should_be_deterministic && !mIsLongTest)
                 {
                     auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
diff --git a/pyproject.toml b/pyproject.toml
index b0e25b6ea93..edc6fbcf8a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ ignore_patterns = [
 [tool.codespell]
 skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json"
 exclude-file = "examples/models/core/whisper/tokenizer.py"
-ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw"
+ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw,dOut"
 
 [tool.autoflake]
 in-place = true
diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
index 98f27fe6ea2..ba71e4fbfe3 100644
--- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -458,7 +458,7 @@ def _(
     gemm2_output: torch.Tensor,
     fc2_expert_biases: torch.Tensor,
     unpermuted_final_scales: torch.Tensor,
-    expanded_source_row_to_expanded_dest_row: torch.Tensor,
+    unpermuted_row_to_permuted_row: torch.Tensor,
     expert_for_source_row: torch.Tensor,
     expert_first_token_offset_tensor: torch.Tensor,
     num_rows: torch.SymInt,
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index b9746b070de..0ca269ad157 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -42,6 +42,7 @@ def __init__(
         use_w4_group_scaling: bool,
         use_mxfp8_act_scaling: bool,
         min_latency_mode: bool,
+        use_fused_finalize: bool,
     ):
         self.x_dtype = x_dtype
         self.weight_dtype = weight_dtype
@@ -59,6 +60,8 @@ def __init__(
         self.use_w4_group_scaling = use_w4_group_scaling
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
         self.min_latency_mode = min_latency_mode
+        self.use_fused_finalize = use_fused_finalize
+
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4_group_scaling,
                         use_mxfp8_act_scaling)
@@ -68,7 +71,7 @@ def __init__(
                 instance_key] = torch.classes.trtllm.FusedMoeRunner(
                     x_dtype, weight_dtype, output_dtype,
                     use_deepseek_fp8_block_scale, use_w4_group_scaling,
-                    use_mxfp8_act_scaling)
+                    use_mxfp8_act_scaling, use_fused_finalize)
         self.fused_moe_runner = MoERunner.runner_dict[instance_key]
 
     def get_valid_tactics(
@@ -143,6 +146,7 @@ def fused_moe(
     use_w4_group_scaling: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
+    use_fused_finalize: bool = True,
     tune_max_num_tokens: int = 8192,
     tuner_num_tokens: Optional[int] = None,
     tuner_top_k: Optional[int] = None,
@@ -179,6 +183,7 @@ def fused_moe(
        use_w4_group_scaling=use_w4_group_scaling,
        use_mxfp8_act_scaling=use_mxfp8_act_scaling,
        min_latency_mode=min_latency_mode,
+       use_fused_finalize=use_fused_finalize,
     )
 
     _, gemm_tactic_1 = tuner.choose_one(
@@ -259,6 +264,7 @@ def _(
     use_w4_group_scaling: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
+    use_fused_finalize: bool = True,
     tune_max_num_tokens: int = 8192,
 ):
     seq_len = input.shape[0]
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 125a637a493..5bc9e7870f4 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -83,6 +83,9 @@ class ModelConfig(Generic[TConfig]):
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM
 
+    # If true, disables FC2+finalize fusion in the CUTLASS MoE backend
+    moe_disable_finalize_fusion: bool = False
+
     allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO
 
     # If true, enable min-latency mode. Currently only used for Llama4.
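The new knob is threaded end to end: `MoeConfig.disable_finalize_fusion` (added to tensorrt_llm/llmapi/llm_args.py below) flows through `PyTorchConfig.moe_disable_finalize_fusion` and `ModelConfig.moe_disable_finalize_fusion` into the `use_fused_finalize` argument of the `trtllm.fused_moe` op and the `FusedMoeRunner` binding. A minimal usage sketch follows, assuming the standard LLM API entry point; the checkpoint path is illustrative and not part of this change.

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi.llm_args import MoeConfig

    # Opt out of the fused FC2+finalize epilogue in the CUTLASS MoE backend.
    # This maps to use_fused_finalize=False on the trtllm.fused_moe op and
    # selects the unfused finalize path, which is deterministic for top-k > 2.
    llm = LLM(
        model="/path/to/moe/checkpoint",  # illustrative path
        moe_config=MoeConfig(disable_finalize_fusion=True),
    )

Because `disable_finalize_fusion` defaults to False and `use_fused_finalize` defaults to True, existing callers keep the fused path unless they explicitly opt out.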
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index f96c6e09f85..c8f54e57011 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -148,6 +148,8 @@ def __init__(
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input
 
+        self.use_fused_finalize = not model_config.moe_disable_finalize_fusion
+
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
@@ -417,6 +419,7 @@ def forward_chunk(
             use_w4_group_scaling=use_w4_group_scaling,
             use_mxfp8_act_scaling=use_mxfp8_act_scaling,
             min_latency_mode=False,
+            use_fused_finalize=self.use_fused_finalize,
             tune_max_num_tokens=self.tune_max_num_tokens,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 14e57661b55..2307ac139f2 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -53,6 +53,8 @@ class PyTorchConfig:
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
 
+    moe_disable_finalize_fusion: bool = False
+
     enable_mixed_sampler: bool = False
     """
     If true, will iterate over sampling_params of each request and use the
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index d39ddc4f2c1..9aa60f40517 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -293,6 +293,8 @@ def __init__(
             checkpoint_loader=checkpoint_loader,
             attn_backend=attn_backend,
             moe_backend=pytorch_backend_config.moe_backend,
+            moe_disable_finalize_fusion=pytorch_backend_config.
+            moe_disable_finalize_fusion,
             load_format=pytorch_backend_config.load_format,
             max_num_tokens=max_num_tokens,
             moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 1169a779be6..71cba1b01e1 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -183,6 +183,12 @@ class MoeConfig(StrictBaseModel):
         description="Configuration for MoE load balancing.",
         json_schema_extra={"type": "Union[MoeLoadBalancerConfig, str]"})
 
+    disable_finalize_fusion: bool = Field(
+        default=False,
+        description=
+        "Disable FC2+finalize kernel fusion in the CUTLASS MoE backend. Setting this to True restores deterministic numerical behavior when top-k > 2."
+    )
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)
@@ -2352,6 +2358,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
             load_format=self.load_format,
             enable_min_latency=self.enable_min_latency,
+            moe_disable_finalize_fusion=self.moe_config.disable_finalize_fusion,
             stream_interval=self.stream_interval,
             force_dynamic_quantization=self.force_dynamic_quantization,
             allreduce_strategy=self.allreduce_strategy,
diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py
index b56caa264a8..cde502df450 100644
--- a/tests/unittest/_torch/modules/test_fused_moe.py
+++ b/tests/unittest/_torch/modules/test_fused_moe.py
@@ -1075,7 +1075,7 @@ def test_fused_moe_nvfp4(dtype):
 
     # compare
     torch.cuda.synchronize()
-    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
 
 
 @skip_neither_ada_nor_hopper_unittest
@@ -1320,7 +1320,7 @@ def test_fused_moe_mxfp4_mxpf8(moe_backend, bias):
 
     # compare
     torch.cuda.synchronize()
-    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
 
 
 @skip_non_hopper_unittest