From 3cfba8ec17db14f35e8fca1c2351d95854972cbc Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 8 Oct 2025 17:50:49 +0000 Subject: [PATCH 01/17] >launcher.inl --- .../launchers/moe_gemm_tma_ws_launcher.inl | 1012 +++++++++-------- 1 file changed, 538 insertions(+), 474 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index a3e4a87398..b0fc36069e 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -165,480 +165,544 @@ using SafeBF16 = void; #endif // TODO Revert this back to a template instantiation once compiler bug is resolved -#define INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM(ArchTag_, DataType_, WeightType_, OutputType_, \ - EpilogueTag_, FUSION_, CTA_M_, CTA_N_, CTA_K_, \ - CGA_M_, CGA_N_, CGA_K_, MXFPX_, BIAS_) \ - static void \ - tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##BIAS_( \ - TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ - int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, \ - size_t* workspace_size) { \ - constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ - /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ - using ArchTag = cutlass::arch::ArchTag_; \ - using T = DataType_; \ - using WeightType = WeightType_; \ - using OutputType = OutputType_; \ - using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ - using TileShape = cute::Shape, cute::Int, cute::Int>; \ - using ClusterShape = cute::Shape, cute::Int, cute::Int>; \ - constexpr static bool IsMXFPX = MXFPX_; \ - \ - if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && \ - ArchTag::kMinComputeCapability >= 90 && ArchTag::kMinComputeCapability < 100) { \ - TLLM_THROW( \ - "Please recompile with support for hopper by passing 90-real as an arch to " \ - "build_wheel.py."); \ - } else if constexpr (!COMPILE_BLACKWELL_TMA_GROUPED_GEMMS_ENABLED && \ - ArchTag::kMinComputeCapability >= 100 && \ - ArchTag::kMinComputeCapability < 120) { \ - TLLM_THROW( \ - "Please recompile with support for blackwell by passing 100-real as an arch to " \ - "build_wheel.py."); \ - } else if constexpr (!COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS_ENABLED && \ - ArchTag::kMinComputeCapability >= 120) { \ - TLLM_THROW( \ - "Please recompile with support for blackwell by passing 120-real as an arch to " \ - "build_wheel.py."); \ - } else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v< \ - ArchTag, TileShape, ClusterShape, T>) { \ - using namespace cute; \ - /* Helper class for defining all the cutlass types \ - // template \ - // struct TmaWarpSpecializedGroupedGemmInfo \ - { */ \ - using Arch = ArchTag; \ - constexpr static bool IsBlackwell = Arch::kMinComputeCapability >= 100; \ - constexpr static bool IsSM120 = \ - Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121; \ - constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value && \ - cutlass::platform::is_same::value; \ - constexpr static bool IsFP4 = cutlass::platform::is_same::value; \ - static_assert(!IsFP4 || IsBlackwell, 
"FP4 is only supported by SM100"); \ - \ - constexpr static bool IsFP8 = cutlass::platform::is_same::value; \ - \ - /* TODO Update once mixed input support is added */ \ - static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \ - "TMA warp specialized MOE implementation does not support mixed input types"); \ - \ - constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8; \ - static_assert(!IsBlockScaled || IsBlackwell, "Block scaled is only implemented for SM100"); \ - \ - static_assert(cutlass::platform::is_same::value || \ - cutlass::platform::is_same::value || \ - cutlass::platform::is_same::value || IsFP8 || IsFP4, \ - "Specialized for bfloat16, half, float, fp8, fp4"); \ - \ - /* The cutlass type for the input elements. This is needed to convert to cutlass::half_t if \ - * necessary.*/ \ - using ElementType = typename TllmToCutlassTypeAdapter::type; \ - \ - /* TODO The below never trigger, and are incorrect for int8 types anyway \ - // using CutlassWeightTypeMaybeUint4 = typename \ - TllmToCutlassTypeAdapter::type; \ - // // For legacy reasons we convert unsigned 8-bit to signed \ - // using CutlassWeightTypeMaybeUint8 \ - // = std::conditional_t, cutlass::int4b_t, \ - // CutlassWeightTypeMaybeUint4>; \ - // using CutlassWeightType \ - // = std::conditional_t, int8_t, \ - // CutlassWeightTypeMaybeUint8>; */ \ - using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; \ - \ - using ElementA = ElementType; \ - using ElementB = CutlassWeightType; \ - \ - using ElementD = typename TllmToCutlassTypeAdapter< \ - TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t>::type; \ - using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; \ - \ - /* using ElementC = std::conditional_t; */ \ - /* using ElementCSafe = std::conditional_t; */ \ - using ElementC = void; \ - using ElementCSafe = ElementD; \ - \ - using ElementAccumulator = float; \ - \ - using ElementBias = ElementFinalOutput; \ - using ElementRouterScales = float; \ - \ - using ElementSF = std::conditional_t< \ - IsMXFPX, cutlass::float_ue8m0_t, \ - cutlass::float_ue4m3_t>; /*TmaWarpSpecializedGroupedGemmInput::ElementSF;*/ \ - using ElementABlockScaled = std::conditional_t, \ - cute::tuple>; \ - using ElementBBlockScaled = std::conditional_t, \ - cute::tuple>; \ - \ - /* A matrix configuration - this is transposed and swapped with B */ \ - using LayoutA = TmaWarpSpecializedGroupedGemmInput::LayoutA; \ - constexpr static int AlignmentA = \ - 128 / \ - cutlass::sizeof_bits::value; /* Memory access granularity/alignment of A \ - matrix in units of elements (up to 16 bytes) */ \ - /* B matrix configuration - this is transposed and swapped with A */ \ - using LayoutB = \ - TmaWarpSpecializedGroupedGemmInput::LayoutB; /* Layout type for B matrix operand */ \ - constexpr static int AlignmentB = \ - IsWFP4AFP8 \ - ? 
128 \ - : (128 / \ - cutlass::sizeof_bits::value); /* Memory access granularity/alignment of \ - B matrix in units \ - // of elements (up to 16 bytes)*/ \ - \ - /* C matrix configuration */ \ - using LayoutC = \ - TmaWarpSpecializedGroupedGemmInput::LayoutC; /* Layout type for C matrix operand */ \ - using StrideC = TmaWarpSpecializedGroupedGemmInput::StrideC; \ - /* Note we use ElementType here deliberately, so we don't break when BIAS is disabled */ \ - constexpr static int AlignmentC = \ - 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment \ - of C matrix in \ - // units of elements (up to 16 bytes)*/ \ - \ - /* D matrix configuration */ \ - using LayoutD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::LayoutD; \ - using StrideD = TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD; \ - constexpr static int AlignmentD = \ - 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment of D \ - matrix \ - // in units of elements (up to 16 bytes) */ \ - \ - static_assert( \ - cutlass::platform::is_same::value, \ - "TMA Warp Specialized Grouped GEMM specialisation doesn't support fused activation"); \ - \ - using EpilogueOp = \ - cutlass::epilogue::fusion::LinearCombination; \ - \ - /* TODO Add mode for fused activation once CUTLASS adds support \ - // using EpilogueSchedule = cutlass::platform::conditional_t< \ - // cutlass::platform::is_same::value, \ - // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ - // cutlass::epilogue::?????????????????? /// <<<<<< what supports \ - activations \ - // >;*/ \ - using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; \ - \ - constexpr static bool Is2SM = IsBlackwell && (cute::size<0>(ClusterShape{}) % 2) == 0; \ - using EpilogueScheduleSM100 = \ - std::conditional_t; \ - using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \ - using EpilogueScheduleBW = \ - std ::conditional_t; \ - using EpilogueSchedule = \ - std::conditional_t; \ - \ - using EpilogueTileShapeSm90 = TileShape; \ - using AtomClusterDiv = std::conditional_t; \ - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape{})); \ - using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using EpilogueTileShape = \ - std::conditional_t; \ - using EpilogueElementC = std::conditional_t; \ - using EpilogueTensorOp = std::conditional_t; \ - using EpilogueSubTile = std::conditional_t< \ - Arch::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \ - cute::Shape, cutlass::epilogue::collective::EpilogueTileAuto>; \ - /* Epilogue For Default Finalize */ \ - using CollectiveEpilogueDefault = typename cutlass::epilogue::collective:: \ - CollectiveBuilder< /**/ \ - Arch, EpilogueTensorOp, /**/ \ - EpilogueTileShape, ClusterShape, /**/ \ - EpilogueSubTile, /**/ \ - ElementAccumulator, ElementAccumulator, /**/ \ - EpilogueElementC, LayoutC*, AlignmentC, /**/ \ - ElementD, LayoutD*, AlignmentD, /**/ \ - EpilogueSchedule>::CollectiveOp; \ - \ - /* Epilogue For Fused Finalize */ \ - using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective:: \ - EpilogueMoeFusedFinalizeBuilder< /**/ \ - Arch, EpilogueTileShape, /**/ \ - ElementCSafe, StrideC*, /**/ \ - ElementFinalOutput, \ - TmaWarpSpecializedGroupedGemmInput:: \ - FusedFinalizeEpilogue::StrideFinalOutput, /**/ \ - ElementAccumulator, /**/ \ - ElementAccumulator, /**/ \ - ElementBias, \ - TmaWarpSpecializedGroupedGemmInput:: \ - FusedFinalizeEpilogue::StrideBias, /**/ \ - 
ElementRouterScales, \ - TmaWarpSpecializedGroupedGemmInput:: \ - FusedFinalizeEpilogue::StrideRouterScales /**/ \ - >::CollectiveOp; \ - \ - using CollectiveEpilogue = \ - std::conditional_t; \ - \ - using StageCountAutoCarveout = \ - cutlass::gemm::collective::StageCountAutoCarveout( \ - sizeof(typename CollectiveEpilogue::SharedStorage))>; \ - \ - using KernelScheduleSM90 = std::conditional_t< \ - IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, \ - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; \ - \ - using KernelSchedule2SmSm100BlockScaled = \ - std::conditional_t; \ - using KernelSchedule1SmSm100BlockScaled = \ - std::conditional_t; \ - \ - /* TRT-LLM uses vector size 16 for block scaled */ \ - using KernelScheduleSM100 = std::conditional_t< \ - Is2SM, \ - std::conditional_t, \ - std::conditional_t>; \ - using KernelScheduleSM120 = cutlass ::gemm ::collective::KernelScheduleAuto; \ - using KernelScheduleBW = \ - std::conditional_t; \ - \ - using KernelSchedule = \ - std::conditional_t; \ - \ - using TensorOp = std::conditional_t; \ - \ - using MainloopElementA = \ - std::conditional_t; \ - using MainloopElementB = \ - std::conditional_t; \ - \ - using MainloopTileShapeSm90 = TileShape; \ - using MainloopTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \ - using MainloopTileShape = \ - std::conditional_t; \ - \ - using CollectiveMainloop = typename cutlass::gemm::collective:: \ - CollectiveBuilder< /**/ \ - Arch, TensorOp, /**/ \ - MainloopElementB, LayoutB*, AlignmentB, /* A & B swapped here */ \ - MainloopElementA, LayoutA*, AlignmentA, /**/ \ - ElementAccumulator, /**/ \ - MainloopTileShape, ClusterShape, /**/ \ - StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \ - \ - using GemmKernel = \ - cutlass::gemm::kernel::GemmUniversal; \ - \ - using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; \ - /*}; \ - \ \ - // using namespace cute; \ - // using GemmInfo = TmaWarpSpecializedGroupedGemmInfo;; \ - // \ - // using ElementAccumulator = typename GemmInfo::ElementAccumulator; \ - // using ElementA = typename GemmInfo::ElementA; \ - // using ElementB = typename GemmInfo::ElementB; \ - // using ElementC = typename GemmInfo::ElementC; \ - // using ElementCSafe = typename GemmInfo::ElementCSafe; \ - // using ElementD = typename GemmInfo::ElementD; \ - // using ElementFinalOutput = typename GemmInfo::ElementFinalOutput; \ - // using ElementBias = typename GemmInfo::ElementBias; \ - // \ - // using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; \ - // using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; \ - // using GemmKernel = typename GemmInfo::GemmKernel; \ - // using GemmGrouped = typename GemmInfo::GemmGrouped;*/ \ - \ - if (kernel_occupancy != nullptr) { \ - TLLM_THROW("TMA WS kernels do not support calculating occupancy"); \ - return; \ - } \ - \ - cutlass::KernelHardwareInfo hw_info; \ - hw_info.device_id = 0; \ - hw_info.sm_count = multi_processor_count; \ - \ - GemmGrouped gemm; \ - \ - if (workspace_size != nullptr) { \ - /* Make a mock problem shape with just the minimal information actually required to get \ - the workspace \ - // size This makes some assumptions about CUTLASS's implementation which is suboptimal. We \ - have a check \ - // later to catch future cutlass updates causing silent breakages, but that is not fool \ - proof. 
The \ - // alternative is to wait until we have data and then dynamically allocate the workspace*/ \ - typename TmaWarpSpecializedGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, \ - nullptr}; \ - \ - typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ - 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ - const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - shape_info, \ - {}, \ - {}, \ - hw_info, \ - scheduler_args}; \ - *workspace_size = gemm.get_workspace_size(args); \ - return; \ - } \ - \ - using MainloopArguments = typename CollectiveMainloop::Arguments; \ - TLLM_CHECK(tma_ws_input.stride_a); \ - TLLM_CHECK(tma_ws_input.stride_b); \ - TLLM_CHECK(tma_ws_input.ptr_a); \ - TLLM_CHECK(tma_ws_input.ptr_b); \ - \ - auto make_mainloop_params = [&]() -> MainloopArguments { \ - if constexpr (IsBlockScaled) { \ - return construct_if_true( \ - reinterpret_cast(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ - reinterpret_cast(tma_ws_input.ptr_a), tma_ws_input.stride_a, \ - reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_B), \ - reinterpret_cast())>( \ - tma_ws_input.fpX_block_scaling_factors_stride_B), \ - reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_A), \ - reinterpret_cast())>( \ - tma_ws_input.fpX_block_scaling_factors_stride_A)); \ - } else { \ - return construct_if_true( \ - reinterpret_cast(tma_ws_input.ptr_b), tma_ws_input.stride_b, \ - reinterpret_cast(tma_ws_input.ptr_a), tma_ws_input.stride_a); \ - } \ - }; \ - \ - auto const mainloop_params = make_mainloop_params(); \ - \ - using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ - using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ - auto make_epilogue_scalars = [&]() { \ - if constexpr (IsBlackwell) { \ - return construct_if_true( \ - ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, \ - nullptr, tma_ws_input.alpha_scale_ptr_array, nullptr, \ - cute::Shape<_0, _0, int64_t>{ \ - cute::_0{}, cute::_0{}, \ - (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ - cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ - } else if (tma_ws_input.alpha_scale_ptr_array) { \ - return construct_if_true( \ - tma_ws_input.alpha_scale_ptr_array); \ - } else { \ - return construct_if_true( \ - ElementAccumulator(1.f), \ - tma_ws_input.ptr_c ? 
ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ - } \ - }; \ - auto epilogue_scalars = make_epilogue_scalars(); \ - /* TODO ptr_c casts to ElementCSafe** because there is a workaround in CUTLASS */ \ - auto make_epi_args = [&]() { \ - static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ - "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ - \ - if constexpr (FUSION == EpilogueFusion::NONE) { \ - auto epi_params = tma_ws_input.default_epilogue; \ - return construct_if_true < FUSION == EpilogueFusion::NONE, \ - EpilogueArguments > (epilogue_scalars, nullptr, tma_ws_input.stride_c, \ - reinterpret_cast(epi_params.ptr_d), \ - epi_params.stride_d); \ - } else if constexpr (FUSION == EpilogueFusion::FINALIZE) { \ - /* Parameters for fused finalize */ \ - auto epi_params = tma_ws_input.fused_finalize_epilogue; \ - return construct_if_true < FUSION == EpilogueFusion::FINALIZE, \ - EpilogueArguments > \ - (epilogue_scalars, /* Parameters to underlying epilogue */ \ - nullptr, tma_ws_input.stride_c, /* C params */ \ - reinterpret_cast(epi_params.ptr_final_output), \ - epi_params.stride_final_output, /* D (output) params */ \ - reinterpret_cast(epi_params.ptr_bias), \ - epi_params.stride_bias, /* Bias params */ \ - epi_params.ptr_router_scales, \ - epi_params.stride_router_scales, /* Router scales */ \ - epi_params.ptr_expert_first_token_offset, /* Offset of this expert's token \ - in the router scales */ \ - epi_params \ - .ptr_source_token_index, /* Index of the source token to sum into */ \ - epi_params \ - .num_rows_in_final_output /* Number of tokens in the output buffer */ \ - ); \ - } \ - }; \ - EpilogueArguments const epilogue_params = make_epi_args(); \ - /* EpilogueArguments const epilogue_params = make_epi_args( \ - // tma_ws_input, epilogue_scalars \ - // );*/ \ - \ - typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ - 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ - \ - const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ - tma_ws_input.shape_info, \ - mainloop_params, \ - epilogue_params, \ - hw_info, \ - scheduler_args}; \ - \ - size_t calculated_ws_size = gemm.get_workspace_size(args); \ - TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ - "Workspace is size %zu but only %zu were allocated", \ - calculated_ws_size, tma_ws_input.gemm_workspace_size); \ - \ - auto can_implement = gemm.can_implement(args); \ - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, \ - "Grouped GEMM kernel will fail for params. Error: " + \ - std::string(cutlass::cutlassGetStatusString(can_implement))); \ - \ - auto init_status = gemm.initialize(args, tma_ws_input.gemm_workspace); \ - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ - "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ - std::string(cutlass::cutlassGetStatusString(init_status))); \ - auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ - "Failed to run cutlass TMA WS grouped gemm. 
Error: " + \ - std::string(cutlass::cutlassGetStatusString(run_status))); \ - sync_check_cuda_error(stream); \ - } else { \ - TLLM_THROW("Configuration was disabled by FAST_BUILD"); \ - } \ - \ - return; \ - } \ - \ - template <> \ - struct DispatchToTmaWSFunction< \ - cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, \ - tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ - cute::Shape, cute::Int, cute::Int>, \ - cute::Shape, cute::Int, cute::Int>, MXFPX_, BIAS_> { \ - constexpr static auto* op = \ - &tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##BIAS_; \ - }; \ - template void tma_warp_specialized_generic_moe_gemm_kernelLauncher< \ - cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, \ - tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ - cute::Shape, cute::Int, cute::Int>, \ - cute::Shape, cute::Int, cute::Int>, MXFPX_, BIAS_>( \ - TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ - int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, \ - size_t* workspace_size); +#define INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM( \ + ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, EpilogueTag_, FUSION_, \ + CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, CGA_K_, MXFPX_, DYNAMIC_CGA_, BIAS_, SWAP_AB_) \ + static void \ + tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueSchedule_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##DYNAMIC_CGA_##_##BIAS_##_##SWAP_AB_( \ + TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ + int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, \ + size_t* workspace_size, cute::Shape dynamic_cluster_shape, \ + cute::Shape fallback_cluster_shape) { \ + using ArchTag = cutlass::arch::ArchTag_; \ + constexpr static EpilogueFusion FUSION = EpilogueFusion::FUSION_; \ + constexpr static bool IsMXFPX = MXFPX_; \ + constexpr static bool DYNAMIC_CGA = DYNAMIC_CGA_; \ + constexpr static bool SwapAB = SWAP_AB_; \ + constexpr bool IsBlackwell = ArchTag::kMinComputeCapability >= 100; \ + constexpr static bool IsSM10x = \ + ArchTag::kMinComputeCapability >= 100 && ArchTag::kMinComputeCapability < 120; \ + constexpr static bool IsSM103 = ArchTag::kMinComputeCapability == 103; \ + constexpr bool IsSM120 = \ + ArchTag::kMinComputeCapability == 120 || ArchTag::kMinComputeCapability == 121; \ + /* constexpr static bool BIAS = BIAS_; */ /* Always false */ \ + using T = DataType_; \ + using WeightType = WeightType_; \ + using OutputType = OutputType_; \ + using EpilogueTag = tensorrt_llm::cutlass_extensions::EpilogueTag_; \ + using InputClusterShape = \ + cute::Shape, cute::Int, cute::Int>; \ + constexpr static bool Is2SM = IsSM10x && cute::size<0>(InputClusterShape{}) == 2; \ + using ClusterShape = std::conditional_t, \ + InputClusterShape>; \ + using MmaTileShape = cute::Shape, cute::Int, \ + cute::Int>; \ + using InputEpilogueSchedule = EpilogueSchedule_; \ + if constexpr (!COMPILE_HOPPER_TMA_GROUPED_GEMMS_ENABLED && \ + ArchTag::kMinComputeCapability >= 90 && ArchTag::kMinComputeCapability < 100) { \ + TLLM_THROW( \ + "Please recompile with support for hopper by passing 90-real as an arch to " \ + "build_wheel.py."); \ + } else if 
constexpr (!COMPILE_BLACKWELL_TMA_GROUPED_GEMMS_ENABLED && \ + ArchTag::kMinComputeCapability >= 100 && \ + ArchTag::kMinComputeCapability < 120) { \ + TLLM_THROW( \ + "Please recompile with support for blackwell by passing 100-real as an arch to " \ + "build_wheel.py."); \ + } else if constexpr (!COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS_ENABLED && \ + ArchTag::kMinComputeCapability >= 120) { \ + TLLM_THROW( \ + "Please recompile with support for blackwell by passing 120-real as an arch to " \ + "build_wheel.py."); \ + } else if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v< \ + ArchTag, MmaTileShape, ClusterShape, DYNAMIC_CGA, T>) { \ + TLLM_CHECK_WITH_INFO(SwapAB == tma_ws_input.swap_ab, "SwapAB must match runtime swap_ab"); \ + using namespace cute; \ + /* Helper class for defining all the cutlass types \ + // template \ + // struct TmaWarpSpecializedGroupedGemmInfo \ + { */ \ + constexpr static bool IsWFP4AFP8 = cutlass::platform::is_same::value && \ + cutlass::platform::is_same::value; \ + constexpr static bool IsFP4 = cutlass::platform::is_same::value; \ + static_assert(!IsFP4 || IsBlackwell, "FP4 is only supported by SM100"); \ + \ + constexpr static bool IsFP8 = cutlass::platform::is_same::value; \ + \ + /* TODO Update once mixed input support is added */ \ + static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \ + "TMA warp specialized MOE implementation does not support mixed input types"); \ + \ + constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8; \ + static_assert(!IsBlockScaled || IsBlackwell, "Block scaled is only implemented for SM100"); \ + \ + static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ + "Unimplemented fusion provided to TMA WS MoE gemm launcher"); \ + constexpr static bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE; \ + constexpr bool IsTmaSM10xEpilogue = \ + std::is_same_v; \ + \ + static_assert(cutlass::platform::is_same::value || \ + cutlass::platform::is_same::value || \ + cutlass::platform::is_same::value || IsFP8 || IsFP4, \ + "Specialized for bfloat16, half, float, fp8, fp4"); \ + \ + /* The cutlass type for the input elements. 
This is needed to convert to cutlass::half_t if \ + * necessary.*/ \ + using ElementType = typename TllmToCutlassTypeAdapter::type; \ + \ + /* TODO The below never trigger, and are incorrect for int8 types anyway \ + // using CutlassWeightTypeMaybeUint4 = typename \ + TllmToCutlassTypeAdapter::type; \ + // // For legacy reasons we convert unsigned 8-bit to signed \ + // using CutlassWeightTypeMaybeUint8 \ + // = std::conditional_t, cutlass::int4b_t, \ + // CutlassWeightTypeMaybeUint4>; \ + // using CutlassWeightType \ + // = std::conditional_t, int8_t, \ + // CutlassWeightTypeMaybeUint8>; */ \ + using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; \ + \ + using ElementAct = ElementType; \ + using ElementWeight = CutlassWeightType; \ + \ + using ElementD = typename TllmToCutlassTypeAdapter< \ + TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t>::type; \ + using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; \ + \ + /* using ElementC = std::conditional_t; */ \ + /* using ElementCSafe = std::conditional_t; */ \ + using ElementC = void; \ + using ElementCSafe = ElementD; \ + \ + using ElementAccumulator = float; \ + \ + using ElementBias = ElementFinalOutput; \ + using ElementRouterScales = float; \ + \ + using ElementSF = std::conditional_t< \ + IsMXFPX, cutlass::float_ue8m0_t, \ + cutlass::float_ue4m3_t>; /*TmaWarpSpecializedGroupedGemmInput::ElementSF;*/ \ + using ElementActBlockScaled = std::conditional_t, \ + cute::tuple>; \ + using ElementWeightBlockScaled = \ + std::conditional_t, \ + cute::tuple>; \ + \ + /* Activation matrix alignment */ \ + constexpr static int AlignmentAct = \ + 128 / \ + cutlass::sizeof_bits::value; /* Memory access granularity/alignment of A \ + matrix in units of elements (up to 16 bytes) */ \ + /* Weight matrix alignment */ \ + constexpr static int AlignmentWeight = \ + IsWFP4AFP8 \ + ? 128 \ + : (128 / \ + cutlass::sizeof_bits::value); /* Memory access \ + granularity/alignment of B matrix in units \ + // of elements (up to 16 bytes)*/ \ + \ + /* C matrix configuration */ \ + /* Note we use ElementType here deliberately, so we don't break when BIAS is disabled */ \ + constexpr static int AlignmentC = \ + 128 / cutlass::sizeof_bits::value; /* Memory access granularity/alignment \ + of C matrix in \ + // units of elements (up to 16 bytes)*/ \ + \ + /* D matrix configuration */ \ + constexpr static int AlignmentDBits = \ + (IsSM10x && !IsTmaSM10xEpilogue) \ + ? 256 \ + : 128; /* For NoSmem epilogue schedule, we need to align to 256 bits */ \ + constexpr static int AlignmentD = \ + AlignmentDBits / cutlass::sizeof_bits::value; /* Memory access \ + granularity/alignment of D matrix \ + // in units of elements (up to 16 bytes) */ \ + \ + static_assert( \ + cutlass::platform::is_same::value, \ + "TMA Warp Specialized Grouped GEMM specialisation doesn't support fused activation"); \ + \ + using EpilogueOp = \ + cutlass::epilogue::fusion::LinearCombination; \ + \ + /* TODO Add mode for fused activation once CUTLASS adds support \ + // using EpilogueSchedule = cutlass::platform::conditional_t< \ + // cutlass::platform::is_same::value, \ + // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, \ + // cutlass::epilogue::?????????????????? 
/// <<<<<< what supports \ + activations \ + // >;*/ \ + using EpilogueScheduleSM90 = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ + \ + using EpilogueScheduleSM10x = std::conditional_t< \ + IsTmaSM10xEpilogue, \ + std::conditional_t, \ + std::conditional_t>; \ + using EpilogueScheduleSM120 = cutlass::epilogue::TmaWarpSpecialized; \ + using EpilogueSchedule = std::conditional_t< \ + IsSM10x, EpilogueScheduleSM10x, \ + std::conditional_t>; \ + using EpilogueElementC = std::conditional_t; \ + using EpilogueTensorOp = std::conditional_t; \ + using EpilogueScheduleSM10xFinalize = std::conditional_t< \ + !IsFinalizeFusion && IsSM10x, \ + std::conditional_t, \ + EpilogueSchedule>; /* This still needs to be valid when finalize fusion is disabled */ \ + \ + using EpilogueSubTile = std::conditional_t< \ + ArchTag::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \ + cute::Shape, cutlass::epilogue::collective::EpilogueTileAuto>; \ + \ + using LayoutC = std::conditional_t; \ + using StrideC = std::conditional_t; \ + using LayoutD = std::conditional_t; \ + using StrideD = std::conditional_t; \ + \ + /* Epilogue For Default Finalize */ \ + using CollectiveEpilogueDefault = typename cutlass::epilogue::collective:: \ + CollectiveBuilder< /**/ \ + ArchTag, EpilogueTensorOp, /**/ \ + MmaTileShape, ClusterShape, /**/ \ + EpilogueSubTile, /**/ \ + ElementAccumulator, ElementAccumulator, /**/ \ + EpilogueElementC, LayoutC*, AlignmentC, /**/ \ + ElementD, LayoutD*, AlignmentD, /**/ \ + EpilogueSchedule>::CollectiveOp; \ + \ + /* Epilogue For Fused Finalize */ \ + using EpilogueFusionOp = std::conditional_t< \ + SwapAB, \ + cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter< \ + LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales>, \ + cutlass::epilogue::fusion::ScaledAccPerColBiasPerRowScaleScatter< \ + LayoutD, ElementFinalOutput, ElementAccumulator, ElementBias, ElementRouterScales>>; \ + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective:: \ + CollectiveBuilder< /**/ \ + ArchTag, EpilogueTensorOp, /**/ \ + MmaTileShape, InputClusterShape, /**/ \ + EpilogueSubTile, /**/ \ + ElementAccumulator, ElementAccumulator, /**/ \ + EpilogueElementC, LayoutC*, AlignmentC, /**/ \ + void, LayoutD*, AlignmentD, /**/ \ + EpilogueScheduleSM10xFinalize, /**/ \ + EpilogueFusionOp /**/ \ + >::CollectiveOp; \ + \ + using CollectiveEpilogue = std::conditional_t; \ + \ + using StageCountAutoCarveout = \ + cutlass::gemm::collective::StageCountAutoCarveout( \ + sizeof(typename CollectiveEpilogue::SharedStorage))>; \ + \ + using KernelScheduleSM90 = std::conditional_t< \ + IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; \ + \ + using KernelSchedule2SmSm100BlockScaled = \ + std::conditional_t; \ + using KernelSchedule1SmSm100BlockScaled = \ + std::conditional_t; \ + \ + /* TRT-LLM uses vector size 16 for block scaled */ \ + using KernelScheduleSM100 = std::conditional_t< \ + Is2SM, \ + std::conditional_t, \ + std::conditional_t>; \ + using KernelScheduleSM103 = std::conditional_t< \ + Is2SM, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103, \ + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103>; \ + using KernelScheduleSM10x = \ + std::conditional_t; \ + using KernelScheduleSM120 = cutlass ::gemm ::collective::KernelScheduleAuto; \ + using KernelScheduleBW = \ + 
std::conditional_t; \ + \ + using KernelSchedule = \ + std::conditional_t; \ + \ + using TensorOp = std::conditional_t; \ + \ + using MainloopElementAct = \ + std::conditional_t; \ + using MainloopElementWeight = std::conditional_t; \ + using SwappedMainloopElementA = \ + std::conditional_t; \ + using SwappedMainloopElementB = \ + std::conditional_t; \ + constexpr auto SwappedAlignmentA = SwapAB ? AlignmentWeight : AlignmentAct; \ + constexpr auto SwappedAlignmentB = SwapAB ? AlignmentAct : AlignmentWeight; \ + using LayoutA = TmaWarpSpecializedGroupedGemmInput::LayoutA; \ + using LayoutB = TmaWarpSpecializedGroupedGemmInput::LayoutB; \ + using StrideA = typename TmaWarpSpecializedGroupedGemmInput::StrideA; \ + using StrideB = typename TmaWarpSpecializedGroupedGemmInput::StrideB; \ + using CollectiveMainloop = typename cutlass::gemm::collective:: \ + CollectiveBuilder< /**/ \ + ArchTag, TensorOp, /**/ \ + SwappedMainloopElementA, LayoutA*, SwappedAlignmentA, /**/ \ + SwappedMainloopElementB, LayoutB*, SwappedAlignmentB, /**/ \ + ElementAccumulator, /**/ \ + MmaTileShape, ClusterShape, /**/ \ + StageCountAutoCarveout, KernelSchedule>::CollectiveOp; \ + \ + using GemmKernel = \ + cutlass::gemm::kernel::GemmUniversal; \ + \ + using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; \ + \ + if (kernel_occupancy != nullptr) { \ + TLLM_THROW("TMA WS kernels do not support calculating occupancy"); \ + return; \ + } \ + \ + cutlass::KernelHardwareInfo hw_info; \ + hw_info.device_id = 0; \ + hw_info.sm_count = multi_processor_count; \ + \ + if constexpr (DYNAMIC_CGA) { \ + TLLM_CHECK(cute::size<0>(dynamic_cluster_shape) >= 1); \ + TLLM_CHECK(cute::size<1>(dynamic_cluster_shape) >= 1); \ + TLLM_CHECK(cute::size<0>(fallback_cluster_shape) >= 1); \ + TLLM_CHECK(cute::size<1>(fallback_cluster_shape) >= 1); \ + TLLM_CHECK_WITH_INFO( \ + cute::size<0>(dynamic_cluster_shape) % cute::size<0>(fallback_cluster_shape) == 0, \ + "Dynamic cluster shape (%dx%d) must be divisible by cluster shape (%dx%d)", \ + (int)cute::size<0>(dynamic_cluster_shape), (int)cute::size<1>(dynamic_cluster_shape), \ + (int)cute::size<0>(fallback_cluster_shape), \ + (int)cute::size<1>(fallback_cluster_shape)); \ + TLLM_CHECK_WITH_INFO( \ + cute::size<0>(fallback_cluster_shape) % cute::size<0>(InputClusterShape{}) == 0, \ + "Fallback cluster shape (%dx%d) must be divisible by MMA cluster shape (%dx%d)", \ + (int)cute::size<0>(fallback_cluster_shape), \ + (int)cute::size<1>(fallback_cluster_shape), (int)cute::size<0>(InputClusterShape{}), \ + (int)cute::size<1>(InputClusterShape{})); \ + hw_info.cluster_shape = \ + dim3(cute::size<0>(dynamic_cluster_shape), cute::size<1>(dynamic_cluster_shape), 1); \ + hw_info.cluster_shape_fallback = \ + dim3(cute::size<0>(fallback_cluster_shape), cute::size<1>(fallback_cluster_shape), 1); \ + } \ + GemmGrouped gemm; \ + \ + if (workspace_size != nullptr) { \ + /* Make a mock problem shape with just the minimal information actually required to get \ + the workspace \ + // size This makes some assumptions about CUTLASS's implementation which is suboptimal. We \ + have a check \ + // later to catch future cutlass updates causing silent breakages, but that is not fool \ + proof. 
The \ + // alternative is to wait until we have data and then dynamically allocate the workspace*/ \ + typename TmaWarpSpecializedGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, \ + nullptr}; \ + \ + typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ + const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ + shape_info, \ + {}, \ + {}, \ + hw_info, \ + scheduler_args}; \ + *workspace_size = gemm.get_workspace_size(args); \ + return; \ + } \ + \ + using MainloopArguments = typename CollectiveMainloop::Arguments; \ + TLLM_CHECK(tma_ws_input.stride_act); \ + TLLM_CHECK(tma_ws_input.stride_weight); \ + TLLM_CHECK(tma_ws_input.ptr_act); \ + TLLM_CHECK(tma_ws_input.ptr_weight); \ + \ + MainloopArguments const mainloop_args = [&] { \ + if constexpr (IsBlockScaled) { \ + if constexpr (SwapAB) { \ + return construct_if_true<(IsBlockScaled && SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight), \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act), \ + reinterpret_cast( \ + tma_ws_input.fpX_block_scaling_factors_weight), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_weight), \ + reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_act), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_act)); \ + } else { \ + return construct_if_true<(IsBlockScaled && !SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act), \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight), \ + reinterpret_cast(tma_ws_input.fpX_block_scaling_factors_act), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_act), \ + reinterpret_cast( \ + tma_ws_input.fpX_block_scaling_factors_weight), \ + reinterpret_cast())>( \ + tma_ws_input.fpX_block_scaling_factors_stride_weight)); \ + } \ + } else { \ + if constexpr (SwapAB) { \ + return construct_if_true<(!IsBlockScaled && SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight), \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act)); \ + } else { \ + return construct_if_true<(!IsBlockScaled && !SwapAB), MainloopArguments>( \ + reinterpret_cast(tma_ws_input.ptr_act), \ + reinterpret_cast(tma_ws_input.stride_act), \ + reinterpret_cast(tma_ws_input.ptr_weight), \ + reinterpret_cast(tma_ws_input.stride_weight)); \ + } \ + } \ + }(); \ + using EpilogueArguments = typename CollectiveEpilogue::Arguments; \ + using EpilogueScalars = decltype(EpilogueArguments{}.thread); \ + EpilogueScalars epilogue_scalars = [&] { \ + constexpr bool IsSimpleAlphaBeta = \ + std::is_constructible_v; \ + if constexpr (IsFinalizeFusion) { \ + auto epi_params = tma_ws_input.fused_finalize_epilogue; \ + if constexpr (SwapAB) { \ + return construct_if_true<(FUSION == EpilogueFusion::FINALIZE && SwapAB), \ + EpilogueScalars>( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast(epi_params.ptr_bias), \ + Stride<_1, _0, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */ \ + reinterpret_cast(epi_params.ptr_final_output), \ + 
epi_params.stride_final_output_transposed, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.shape_override, \ + epi_params.use_reduction); \ + } else { \ + return construct_if_true<(FUSION == EpilogueFusion::FINALIZE && !SwapAB), \ + EpilogueScalars>( \ + ElementAccumulator(1), nullptr, tma_ws_input.alpha_scale_ptr_array, \ + Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */ \ + reinterpret_cast(epi_params.ptr_bias), \ + Stride<_0, _1, int64_t>{}, /* bias */ \ + epi_params.ptr_router_scales, Stride<_1, _0, int64_t>{}, /* scale */ \ + reinterpret_cast(epi_params.ptr_final_output), \ + epi_params.stride_final_output, epi_params.ptr_source_token_index, \ + epi_params.num_rows_in_final_output, epi_params.shape_override, \ + epi_params.use_reduction); \ + } \ + } else if constexpr (!IsSimpleAlphaBeta) { \ + return construct_if_true<(!IsSimpleAlphaBeta && !IsFinalizeFusion), EpilogueScalars>( \ + ElementAccumulator(1.f), \ + tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f), nullptr, \ + nullptr, tma_ws_input.alpha_scale_ptr_array, nullptr, \ + cute::Shape<_0, _0, int64_t>{ \ + cute::_0{}, cute::_0{}, \ + (tma_ws_input.alpha_scale_ptr_array != nullptr) ? 1 : 0}, \ + cute::Shape<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 0}); \ + } else if (tma_ws_input.alpha_scale_ptr_array) { \ + return construct_if_true<(IsSimpleAlphaBeta && !IsFinalizeFusion), EpilogueScalars>( \ + tma_ws_input.alpha_scale_ptr_array); \ + } else { \ + return construct_if_true<(IsSimpleAlphaBeta && !IsFinalizeFusion), EpilogueScalars>( \ + ElementAccumulator(1.f), \ + tma_ws_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); \ + } \ + }(); \ + \ + EpilogueArguments epilogue_args = [&] { \ + if constexpr (FUSION == EpilogueFusion::FINALIZE) { \ + return construct_if_true < FUSION == EpilogueFusion::FINALIZE, \ + EpilogueArguments > (epilogue_scalars, nullptr, nullptr, nullptr, nullptr); \ + } else { \ + return construct_if_true < FUSION != EpilogueFusion::FINALIZE, \ + EpilogueArguments > (epilogue_scalars, nullptr, nullptr, \ + reinterpret_cast(tma_ws_input.ptr_d), \ + reinterpret_cast(tma_ws_input.stride_d)); \ + } \ + }(); \ + \ + typename GemmKernel::TileScheduler::Arguments scheduler_args{ \ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; \ + \ + const typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, \ + tma_ws_input.shape_info, \ + mainloop_args, \ + epilogue_args, \ + hw_info, \ + scheduler_args}; \ + \ + size_t calculated_ws_size = gemm.get_workspace_size(args); \ + TLLM_CHECK_WITH_INFO(calculated_ws_size <= tma_ws_input.gemm_workspace_size, \ + "Workspace is size %zu but only %zu were allocated", \ + calculated_ws_size, tma_ws_input.gemm_workspace_size); \ + \ + auto can_implement = gemm.can_implement(args); \ + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, \ + "Grouped GEMM kernel will fail for params. Error: " + \ + std::string(cutlass::cutlassGetStatusString(can_implement))); \ + \ + auto init_status = gemm.initialize(args, tma_ws_input.gemm_workspace); \ + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ + "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ + std::string(cutlass::cutlassGetStatusString(init_status))); \ + auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ + "Failed to run cutlass TMA WS grouped gemm. 
Error: " + \ + std::string(cutlass::cutlassGetStatusString(run_status))); \ + sync_check_cuda_error(stream); \ + } else { \ + TLLM_THROW("Configuration was disabled by FAST_BUILD"); \ + } \ + \ + return; \ + } \ + \ + template <> \ + struct DispatchToTmaWSFunction< \ + cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, \ + tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ + cute::Shape, cute::Int, cute::Int>, \ + cute::Shape, cute::Int, cute::Int>, MXFPX_, DYNAMIC_CGA_, \ + BIAS_, SWAP_AB_> { \ + constexpr static auto* op = &tma_warp_specialized_generic_moe_gemm_kernelLauncher_##ArchTag_##_##DataType_##_##WeightType_##_##OutputType_##_##EpilogueSchedule_##_##EpilogueTag_##_##FUSION_##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##_##MXFPX_##_##DYNAMIC_CGA_##_##BIAS_##_##SWAP_AB_; \ + }; \ + template void tma_warp_specialized_generic_moe_gemm_kernelLauncher< \ + cutlass::arch::ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, \ + tensorrt_llm::cutlass_extensions::EpilogueTag_, EpilogueFusion::FUSION_, \ + cute::Shape, cute::Int, cute::Int>, \ + cute::Shape, cute::Int, cute::Int>, MXFPX_, DYNAMIC_CGA_, \ + BIAS_, SWAP_AB_>(TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, \ + int const multi_processor_count, cudaStream_t stream, \ + int* kernel_occupancy, size_t* workspace_size, \ + cute::Shape dynamic_cluster_shape, \ + cute::Shape fallback_cluster_shape); } // namespace cutlass_kernels } // namespace kernels From 9047135340a7dbc353e37f937e1909e17b9d3fc3 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 8 Oct 2025 21:54:10 +0000 Subject: [PATCH 02/17] >generate_kernels.py --- .../jit/gemm/cutlass/generate_kernels.py | 234 +++++++++++++----- 1 file changed, 168 insertions(+), 66 deletions(-) diff --git a/flashinfer/jit/gemm/cutlass/generate_kernels.py b/flashinfer/jit/gemm/cutlass/generate_kernels.py index e767361a65..5bc6dea01e 100644 --- a/flashinfer/jit/gemm/cutlass/generate_kernels.py +++ b/flashinfer/jit/gemm/cutlass/generate_kernels.py @@ -144,6 +144,8 @@ def __init__( epi_schedule, epi_fusion=None, is_mx_fpx=False, + dynamic_cga=False, + swap_ab=False, ): self.gemm_kind = gemm_kind self.arch = arch @@ -158,10 +160,12 @@ def __init__( self.warp_shape = warp_shape self.stages = stages self.cga_shape = cga_shape + self.dynamic_cga = dynamic_cga self.mainloop_schedule = mainloop_schedule self.epi_schedule = epi_schedule self.epi_fusion = epi_fusion self.is_mx_fpx = is_mx_fpx + self.swap_ab = swap_ab def __repr__(self): kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format( @@ -183,13 +187,15 @@ def __repr__(self): self.stages, ) - hopper_suffix = "_{}x{}x{}{}{}{}".format( + hopper_suffix = "_{}x{}x{}{}{}{}{}{}".format( self.cga_shape[0], self.cga_shape[1], self.cga_shape[2], KernelScheduleSuffixes[self.mainloop_schedule], EpilogueScheduleSuffixes[self.epi_schedule], EpiFusionSuffixes[self.epi_fusion], + "_mxfpx_" if self.is_mx_fpx else "", + "_swap_ab" if self.swap_ab else "", ) if self.arch >= 90: @@ -217,7 +223,9 @@ def instantiate_operation_tma_warp_specialized(operation): cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) kernel_sched = KernelScheduleTag[operation.mainloop_schedule] - epi_sched = EpilogueScheduleTag[operation.epi_schedule] + epi_sched = "void" + if operation.epi_schedule is not None: + epi_sched = EpilogueScheduleTag[operation.epi_schedule] if operation.gemm_kind == GemmKind.Gemm: weight_tag = DataTypeTag[operation.weight_type] @@ 
-228,8 +236,7 @@ def instantiate_operation_tma_warp_specialized(operation): {kernel_sched}, {epi_sched}> ( const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, {out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* -); -""" +);""" elif operation.gemm_kind == GemmKind.Grouped: if operation.act_type != operation.weight_type and ( operation.act_type != DataType.e4m3 or operation.weight_type != e2m1 @@ -247,18 +254,21 @@ def instantiate_operation_tma_warp_specialized(operation): KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, ] - assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized kernel_sched.replace("::Kernel", "::KernelGrouped") - epi_sched += "Grouped" - + # epi_sched += "Grouped" # arch_tag = f"cutlass::arch::Sm{operation.arch}" arch_tag = f"Sm{operation.arch}" weight_tag = CudaTypeName[operation.weight_type] assert operation.epi_fusion is not None epi_fusion = EpiFusion[operation.epi_fusion] + # We need to remove the '::' because this will break the instantiation macro epi_fusion = epi_fusion.split(":")[-1] epi_tag = epi_tag.split(":")[-1] + epi_sched = epi_sched.split(":")[-1] + epi_sched = epi_sched.replace( + "1Sm", "" + ) # Hack to WAR missing `PtrArrayTmaWarpSpecialized` type guard_map = { e2m1: "defined(ENABLE_FP4)", @@ -274,11 +284,9 @@ def instantiate_operation_tma_warp_specialized(operation): # (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*); # """ instantiation = f""" -#if {guard_act} && {guard_weight}\n - INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, - {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n -#endif -""" +#if {guard_act} && {guard_weight} + INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, {epi_sched}, {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, {"true" if operation.dynamic_cga else "false"}, false, {"true" if operation.swap_ab else "false"}); +#endif""" return instantiation @@ -289,8 +297,7 @@ def instantiate_operation_sm80(operation): instantiation = f""" template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}> - ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); - """ + ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);""" return instantiation @@ -318,12 +325,12 @@ def get_file_content(launcher_inl_files, operations): {{ namespace kernels {{ -namespace 
cutlass_kernels +namespace cutlass_kernels_oss {{ {instantiations} -}} // namespace cutlass_kernels +}} // namespace cutlass_kernels_oss }} // namespace kernels }} // namespace tensorrt_llm """ @@ -353,17 +360,28 @@ def write_file(launcher_inl_files, operations, output_file): f.write(content) -from operator import mul, truediv - +def is_gemm_op_valid_sm100(op): + # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future + tile_m, tile_n, _ = op.cta_shape + cga_m, cga_n, cga_k = op.cga_shape -def elementwise(x, y, f): - return tuple(f(a, b) for (a, b) in zip(x, y)) + if ( + op.epi_fusion == TrtLlm_EpilogueFusion.epilogue_fusion_finalize + and op.epi_schedule != EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm + ): + return False + # We use a runtime cluster shape for SM100, so we only use cluster shapes to distinguish between 1SM and 2SM variants. + if cga_m > 2 or cga_n != 1 or cga_k != 1: + return False -def is_gemm_op_valid_sm100(op): - # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future - tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv) - cga_m, cga_n, _ = op.cga_shape + if op.arch == 103: + return ( + op.act_type == e2m1 + and op.weight_type == e2m1 + and tile_m == 128 + and tile_n in [128, 256] + ) # Default shapes # This is epilogue tile size. For two CTA this is actually size 128/256 for the MMA @@ -372,23 +390,23 @@ def is_gemm_op_valid_sm100(op): # FP4 Has some much more limited sizes if op.act_type == e2m1 or op.weight_type == e2m1: - # TODO 128x256x256 FP4 compiles but crashes - # if tile_n % 64 != 0 or tile_n < 128: - # return False if tile_n not in [64, 128, 256] or tile_m != 128: return False + # TODO Revert this once cutlass adds support for blockscaled + no smem + if ( + op.arch == 100 + and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm + ): + return False # Shapes for fp8 small N shapes if ( - op.act_type == DataType.e4m3 + (op.act_type == DataType.e4m3) and (tile_n == 16 or tile_n == 8) and (cga_m == 1 and cga_n == 1) ): - # todo: double check why this is disable in CUTLASS backend. @yuhan - if tile_m == 128 and tile_n == 8: - return False - else: - return True + # todo: double check why tile_n = 8 is disabled in CUTLASS backend. 
@yuhan + return tile_m != 128 or tile_n % 16 == 0 # Default alignment requirements if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256: @@ -427,7 +445,10 @@ def is_grouped_gemm_op_valid(op): if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default: return False - if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized: + if ( + op.epi_schedule is not None + and op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized + ): return False if op.mainloop_schedule not in [ @@ -543,14 +564,30 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] + swap_ab = [True, False] + cga_shapes = product([1, 2], [1, 2], [1]) partial_args = product( - supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes + supported_dtypes, + quant_ops, + epi_tags, + epi_fusions, + cta_shapes_mn, + cga_shapes, + swap_ab, ) operations = list() - for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: + for ( + dtype, + quant_op, + epi_tag, + epi_fusion, + cta_shape_mn, + cga_shape, + swap_ab, + ) in partial_args: max_k_bits = 128 * 8 cta_shape_k = max_k_bits // GetDataTypeBits(dtype) cta_shape_mnk = cta_shape_mn + (cta_shape_k,) @@ -560,7 +597,7 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): if dtype != DataType.e4m3 else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum ) - epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized + epi_schedule = None otypes = [dtype] if dtype == DataType.e4m3: @@ -584,6 +621,7 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): mainloop_schedule, epi_schedule, epi_fusion, + swap_ab=swap_ab, ) if is_op_valid(moe_gemm_operation): @@ -693,8 +731,6 @@ def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype): cta_shape_k = max_k_bits // GetDataTypeBits(dtype) if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8): cta_shape_k = 256 - if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16): - cta_shape_k = 128 return cta_shape_mn + (cta_shape_k,) @@ -702,7 +738,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): if not is_arch_enabled: return [] arch = 120 - supported_dtypes = [e2m1] + supported_dtypes = [e2m1, (DataType.e4m3, e2m1)] quant_ops = [TrtLlm_QuantOp.none] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] cta_shapes_mnk = [ @@ -717,45 +753,71 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] cga_shapes = [[1, 1, 1]] + swap_ab = [True, False] + partial_args = product( - supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mnk, cga_shapes + supported_dtypes, + quant_ops, + epi_tags, + epi_fusions, + cta_shapes_mnk, + cga_shapes, + swap_ab, ) operations = list() - for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args: - cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) - + for ( + dtype, + quant_op, + epi_tag, + epi_fusion, + cta_shape_mnk, + cga_shape, + swap_ab, + ) in partial_args: # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative - epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized + epi_schedule = None - otypes = [dtype] - if dtype in [DataType.e4m3, e2m1]: + if isinstance(dtype, tuple): + act_type, weight_type = dtype + else: + act_type, weight_type = dtype, dtype + + # Minimal filter: for mixed FP8xFP4 on SM120, only emit 128x128x128 + if act_type 
== DataType.e4m3 and weight_type == e2m1: + if cta_shape_mnk != [128, 128, 128]: + continue + + otypes = [act_type] + if act_type in [DataType.e4m3, e2m1]: otypes = [DataType.f16, DataType.bf16] for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, arch, - dtype, - dtype, - dtype, - dtype, + act_type, + weight_type, + act_type, + act_type, otype, quant_op, epi_tag, - cga_tile_shape_mnk, + cta_shape_mnk, warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule, epi_fusion, + is_mx_fpx=(act_type == DataType.e4m3 and weight_type == e2m1), + swap_ab=swap_ab, ) operations.append(moe_gemm_operation) @@ -767,10 +829,9 @@ def generate_sm120_operations(is_arch_enabled): return operations -def generate_sm100_grouped_gemm_operations(is_arch_enabled): +def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch): if not is_arch_enabled: return [] - arch = 100 supported_dtypes = [ DataType.f16, DataType.bf16, @@ -782,7 +843,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): quant_ops = [TrtLlm_QuantOp.none] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] cta_shapes_m = [64, 128] - cta_shapes_n = [8, 16, 32, 64, 128, 256] + cta_shapes_n = [8, 16, 32, 64, 128, 192, 256] cta_shapes_mn = product(cta_shapes_m, cta_shapes_n) warp_shape = [0, 0, 0] # ignored except for naming @@ -790,28 +851,55 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize + TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] - cga_shapes = list(product([1, 2], [1, 2], [1])) + # Some shapes for SM100 are better with NoSmem, note the kernel will internally map to the 1 or 2 SM variants based on the cga_shape[0] + epi_schedules = [ + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + ] + + # We will use dynamic cluster shapes for SM100, so we only need to indicate if we are using 1 or 2 SM version + cga_shapes = [(1, 1, 1), (2, 1, 1)] + + swap_ab = [True, False] + + dynamic_cga = [True, False] partial_args = product( - supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes + supported_dtypes, + quant_ops, + epi_tags, + epi_fusions, + cta_shapes_mn, + cga_shapes, + epi_schedules, + dynamic_cga, + swap_ab, ) operations = list() - for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: + for ( + dtype, + quant_op, + epi_tag, + epi_fusion, + cta_shape_mn, + cga_shape, + epi_schedule, + dynamic_cga, + swap_ab, + ) in partial_args: if isinstance(dtype, tuple): dtype, weight_type = dtype else: weight_type = dtype cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype) - cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative - epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized otypes = [dtype] if dtype in [DataType.e4m3, e2m1]: @@ -828,7 +916,7 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): otype, quant_op, epi_tag, - cga_tile_shape_mnk, + cta_shape_mnk, warp_shape, stages, cga_shape, @@ -836,6 +924,8 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): epi_schedule, epi_fusion, is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1), + dynamic_cga=dynamic_cga, + swap_ab=swap_ab, ) if is_op_valid(moe_gemm_operation): @@ -843,8 +933,13 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): return operations +def 
generate_sm103_operations(is_arch_enabled): + operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 103) + return operations + + def generate_sm100_operations(is_arch_enabled): - operations = generate_sm100_grouped_gemm_operations(is_arch_enabled) + operations = generate_sm100_grouped_gemm_operations(is_arch_enabled, 100) return operations @@ -908,18 +1003,25 @@ def generate_gemm_operations(output_dir, architectures): (GemmKind.Gemm, 90): [fpA_intB_inl], (GemmKind.Grouped, 90): [moe_gemm_inl], (GemmKind.Grouped, 100): [moe_gemm_inl], + (GemmKind.Grouped, 103): [moe_gemm_inl], (GemmKind.Grouped, 120): [moe_gemm_inl], (GemmKind.Grouped, 80): [sm80_moe_gemm_inl], } def has_arch(sm): - return f"{sm}" in arches or f"{sm}-real" in arches + return ( + f"{sm}" in arches + or f"{sm}-real" in arches + or f"{sm}f-real" in arches + or f"{sm}f" in arches + ) # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. operations = [] operations += generate_sm120_operations(has_arch(120) or has_arch(121)) - operations += generate_sm100_operations(has_arch(100)) + operations += generate_sm103_operations(has_arch(103)) + operations += generate_sm100_operations(has_arch(100) or has_arch(103)) operations += generate_sm90_operations(has_arch(90)) operations += generate_sm80_operations(has_arch(80) or has_arch(89)) From d9d77232d3cefaffa07bdd01ebb2c0d3b03bea32 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 8 Oct 2025 21:58:16 +0000 Subject: [PATCH 03/17] >generate_kernels.py --- flashinfer/jit/gemm/cutlass/generate_kernels.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/flashinfer/jit/gemm/cutlass/generate_kernels.py b/flashinfer/jit/gemm/cutlass/generate_kernels.py index 5bc6dea01e..f7a87bedbd 100644 --- a/flashinfer/jit/gemm/cutlass/generate_kernels.py +++ b/flashinfer/jit/gemm/cutlass/generate_kernels.py @@ -277,6 +277,11 @@ def instantiate_operation_tma_warp_specialized(operation): } guard_act = guard_map.get(operation.act_type, "1") guard_weight = guard_map.get(operation.weight_type, "1") + + is_mx_fpx = str(operation.is_mx_fpx).lower() + use_dynamic_cga = str(operation.dynamic_cga).lower() + use_bias = str(False).lower() + swap_ab = str(operation.swap_ab).lower() # TODO Revert this once compiler bug is fixed so we can use template instead of macro again # instantiation = f""" # template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag}, @@ -285,7 +290,10 @@ def instantiate_operation_tma_warp_specialized(operation): # """ instantiation = f""" #if {guard_act} && {guard_weight} - INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, {epi_sched}, {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, {"true" if operation.dynamic_cga else "false"}, false, {"true" if operation.swap_ab else "false"}); + INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, + {epi_sched}, {epi_tag}, {epi_fusion}, + {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, + {is_mx_fpx}, {use_dynamic_cga}, 
{use_bias}, {swap_ab}); #endif""" return instantiation From 31662459986eca0fa3f0a328e6c872af0c48ab45 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 8 Oct 2025 22:51:09 +0000 Subject: [PATCH 04/17] >launcher.inl --- .../launchers/moe_gemm_tma_ws_launcher.inl | 75 ++++++++++++------- 1 file changed, 50 insertions(+), 25 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index b0fc36069e..db5788bfdd 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,13 +18,13 @@ #include #include +#include "../../include/moe_gemm_kernels.h" #include "../moe_tma_warp_specialized_traits.h" #include "cute/tensor.hpp" #include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/fusion/operations.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -33,14 +33,7 @@ #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_ref.h" -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" -#include "moe_gemm_kernels.h" +#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp" #include "moe_gemm_tma_ws_launcher.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -58,7 +51,8 @@ namespace tensorrt_llm { namespace kernels { -namespace cutlass_kernels { +namespace cutlass_kernels_oss { +using namespace tensorrt_llm::kernels::cutlass_kernels; using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; // Constructs an object with specific arguments only if flag is true @@ -76,8 +70,18 @@ ReturnType construct_if_true(Args&&... args) { template auto deduce_layout_sf() { if constexpr (FLAG && A) { + // In moe_kernels.cu we rely on these two types being the same. This is not necessarily + // guaranteed by cutlass so we have a sanity check here. + static_assert(std::is_same_v, + "Deduced layout SF does not match for A and B"); return typename GemmGrouped::GemmKernel::CollectiveMainloop::LayoutSFA{}; } else if constexpr (FLAG && !A) { + // In moe_kernels.cu we rely on these two types being the same. This is not necessarily + // guaranteed by cutlass so we have a sanity check here. 
+ static_assert(std::is_same_v, + "Deduced layout SF does not match for A and B"); return typename GemmGrouped::GemmKernel::CollectiveMainloop::LayoutSFB{}; } else { return (void*)nullptr; @@ -85,18 +89,21 @@ auto deduce_layout_sf() { } template + typename EpilogueSchedule, typename EpilogueTag, EpilogueFusion FUSION, + typename TileShape, typename ClusterShape, bool IsMXFPX, bool DYNAMIC_CGA, bool BIAS, + bool SwapAB> struct DispatchToTmaWSFunction {}; // TMA WS specialized version template + typename EpilogueSchedule, typename EpilogueTag, EpilogueFusion FUSION, + typename TileShape, typename ClusterShape, bool IsMXFPX, bool DYNAMIC_CGA, bool BIAS, + bool SwapAB> void tma_warp_specialized_generic_moe_gemm_kernelLauncher( TmaWarpSpecializedGroupedGemmInput tma_ws_input, int num_experts, int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, - size_t* workspace_size) { + size_t* workspace_size, cute::Shape dynamic_cluster_shape, + cute::Shape fallback_cluster_shape) { if constexpr (ArchTag::kMinComputeCapability < 90) { TLLM_THROW("Invalid architecture instantiated"); } @@ -115,6 +122,14 @@ void tma_warp_specialized_generic_moe_gemm_kernelLauncher( "build_wheel.py."); } #endif +#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS + else if constexpr (ArchTag::kMinComputeCapability == 103) { + // fallback sm100f logic is done in dispatchMoeGemmFinalDispatchTmaWarpSpecialized + TLLM_THROW( + "Please recompile with support for blackwell by passing 103-real as an arch to " + "build_wheel.py."); + } +#endif #ifndef COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS else if constexpr (ArchTag::kMinComputeCapability >= 120) { TLLM_THROW( @@ -123,10 +138,13 @@ void tma_warp_specialized_generic_moe_gemm_kernelLauncher( } #endif else { - return DispatchToTmaWSFunction::op(tma_ws_input, num_experts, multi_processor_count, - stream, kernel_occupancy, workspace_size); + return DispatchToTmaWSFunction::op(tma_ws_input, num_experts, + multi_processor_count, stream, + kernel_occupancy, workspace_size, + dynamic_cluster_shape, + fallback_cluster_shape); } } @@ -164,6 +182,8 @@ using SafeBF16 = __nv_bfloat16; using SafeBF16 = void; #endif +using namespace cutlass::epilogue; + // TODO Revert this back to a template instantiation once compiler bug is resolved #define INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM( \ ArchTag_, DataType_, WeightType_, OutputType_, EpilogueSchedule_, EpilogueTag_, FUSION_, \ @@ -286,10 +306,15 @@ using SafeBF16 = void; using ElementSF = std::conditional_t< \ IsMXFPX, cutlass::float_ue8m0_t, \ cutlass::float_ue4m3_t>; /*TmaWarpSpecializedGroupedGemmInput::ElementSF;*/ \ - using ElementActBlockScaled = std::conditional_t, \ - cute::tuple>; \ + using ElementActBlockScaled = \ + std::conditional_t, \ + cutlass::nv_float4_t>, \ + cute::tuple>; \ using ElementWeightBlockScaled = \ - std::conditional_t, \ + std::conditional_t, \ + cutlass::nv_float4_t>, \ cute::tuple>; \ \ /* Activation matrix alignment */ \ @@ -704,6 +729,6 @@ using SafeBF16 = void; cute::Shape dynamic_cluster_shape, \ cute::Shape fallback_cluster_shape); -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm From 4dce85b59dfb4f8ea50ac182ec26e0d6e9e5ff79 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 8 Oct 2025 23:00:37 +0000 Subject: [PATCH 05/17] >moe_gemm_kernels.h --- .../include/moe_gemm_kernels.h | 135 ++++++++---------- 1 file changed, 57 insertions(+), 78 deletions(-) diff --git 
a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index b85decebcd..916e8ab78f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -15,10 +15,10 @@ */ #pragma once -#include #include #include +#include #include #include "./common.h" @@ -32,17 +32,10 @@ #include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" #ifdef ENABLE_FP4 -#if CUDA_VERSION >= 12080 #include #endif -#endif namespace tensorrt_llm::kernels::cutlass_kernels { -template -constexpr auto transpose_stride(T const& t) { - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), - cute::get<1>(t)); -} template struct GroupedGemmInput { @@ -71,8 +64,6 @@ struct GroupedGemmInput { }; struct TmaWarpSpecializedGroupedGemmInput { - template - using TransposeStride = decltype(transpose_stride(T{})); template using TransposeLayoutTag = std::conditional_t, @@ -83,14 +74,24 @@ struct TmaWarpSpecializedGroupedGemmInput { static_assert( std::is_same_v>); - // Layout for A and B is transposed and then swapped in the implementation - // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM - using LayoutA = - TransposeLayoutTag; // Layout type for A matrix operand - using LayoutB = - TransposeLayoutTag; // Layout type for B matrix operand - using LayoutC = - TransposeLayoutTag; // Layout type for C matrix operand + // These are always the layout of A & B matrices, activations and weights will be assigned to + // either A or B based on swap_ab + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + // When using Swap A&B we need to transpose the output matrix + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutC_T = TransposeLayoutTag; + using LayoutD_T = TransposeLayoutTag; + + using StrideA = std::remove_pointer_t>; + using StrideB = std::remove_pointer_t>; + + using StrideC = std::remove_pointer_t>; + using StrideD = std::remove_pointer_t>; + using StrideC_T = std::remove_pointer_t>; + using StrideD_T = std::remove_pointer_t>; constexpr static int NVFP4BlockScaleVectorSize = 16; constexpr static int MXFPXBlockScaleVectorSize = 32; @@ -121,14 +122,6 @@ struct TmaWarpSpecializedGroupedGemmInput { return (dim + alignment - 1) / alignment * alignment; } - using StrideA = - std::remove_pointer_t>; // Use B because they will - // be swapped - using StrideB = - std::remove_pointer_t>; // Use A because they will - // be swapped - using StrideC = std::remove_pointer_t>; - #ifdef ENABLE_FP8 template constexpr static bool IsFP8_v = @@ -144,47 +137,40 @@ struct TmaWarpSpecializedGroupedGemmInput { using ProblemShape = cutlass::gemm::GroupProblemShape>; + bool swap_ab = false; ProblemShape shape_info{}; - StrideA* stride_a = nullptr; - StrideB* stride_b = nullptr; + void* stride_act = nullptr; + void* stride_weight = nullptr; - void const** ptr_a = nullptr; - void const** ptr_b = nullptr; + void const** ptr_act = nullptr; + void const** ptr_weight = nullptr; // C is currently the same in both epilogues - StrideC* stride_c = nullptr; + void* stride_c = nullptr; void const** ptr_c = nullptr; - struct DefaultEpilogue { - using LayoutD = - TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* 
stride_d = nullptr; - void** ptr_d = nullptr; - }; + // D is used in all cases except fused finalize + void* stride_d = nullptr; + void** ptr_d = nullptr; struct FusedFinalizeEpilogue { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; + using StrideFinalOutput_T = cutlass::detail::TagToStrideC_t; + using StrideFinalOutput = cutlass::detail::TagToStrideC_t; void* ptr_final_output = nullptr; + StrideFinalOutput_T stride_final_output_transposed{}; StrideFinalOutput stride_final_output{}; - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; + void const** ptr_bias = nullptr; + float const** ptr_router_scales = nullptr; - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; + int const** ptr_source_token_index = nullptr; + int num_rows_in_final_output = 0; + int shape_override = -1; - size_t num_rows_in_final_output = 0; + bool use_reduction = true; }; - DefaultEpilogue default_epilogue; FusedFinalizeEpilogue fused_finalize_epilogue; enum class EpilogueFusion { NONE, ACTIVATION, GATED_ACTIVATION, FINALIZE }; @@ -195,11 +181,11 @@ struct TmaWarpSpecializedGroupedGemmInput { using ElementSF = uint8_t; using MXFPXElementSF = ElementSF; // Just an alias for now using NVFP4ElementSF = ElementSF; // Just an alias for now - ElementSF const** fpX_block_scaling_factors_A = nullptr; - ElementSF const** fpX_block_scaling_factors_B = nullptr; + ElementSF const** fpX_block_scaling_factors_act = nullptr; + ElementSF const** fpX_block_scaling_factors_weight = nullptr; - void* fpX_block_scaling_factors_stride_A = nullptr; - void* fpX_block_scaling_factors_stride_B = nullptr; + void* fpX_block_scaling_factors_stride_act = nullptr; + void* fpX_block_scaling_factors_stride_weight = nullptr; enum class FpXBlockScalingType { MXFPX, NVFP4, NONE }; FpXBlockScalingType fpX_block_scaling_type = FpXBlockScalingType::NONE; @@ -230,22 +216,17 @@ struct TmaWarpSpecializedGroupedGemmInput { uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; - // Whether to enable PDL (Programmatic Dependent Launch). 
- bool enable_pdl; - - static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); + static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size, FpXBlockScalingType scaling_type); - bool isValid() const { return stride_a != nullptr && ptr_a != nullptr; } + bool isValid() const { return stride_act != nullptr && ptr_act != nullptr; } - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, - int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); + void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, + bool use_reduction); std::string toString() const; }; @@ -264,7 +245,6 @@ class MoeGemmRunner { public: MoeGemmRunner(); -#if defined(ENABLE_FP4) #if defined(ENABLE_BF16) static constexpr bool use_wfp4a16 = std::is_same_v && (std::is_same_v || std::is_same_v); @@ -272,10 +252,6 @@ class MoeGemmRunner { static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; #endif -#else - static constexpr bool use_wfp4a16 = false; -#endif - #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) && @@ -289,17 +265,16 @@ class MoeGemmRunner { #else static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; - static constexpr bool use_wfp4afp4 = false; #endif static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool use_fp4 = std::is_same_v; - static constexpr bool use_wfp4afp4 = + static constexpr bool use_wfp4afp8 = std::is_same_v && std::is_same_v; #else static constexpr bool use_fp4 = false; - static constexpr bool use_wfp4afp4 = false; + static constexpr bool use_wfp4afp8 = false; #endif void moeGemmBiasAct(GroupedGemmInput inputs, @@ -308,15 +283,19 @@ class MoeGemmRunner { void moeGemm(GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs); - std::vector getConfigs() const; - static std::vector getConfigs(int sm); - static std::vector getTmaWarpSpecializedConfigs(int sm); - static std::vector getBlackwellConfigs(int sm); - static std::vector getHopperConfigs(int sm); + std::vector getConfigs( + bool supports_finalize_fusion) const; + static std::vector getConfigs( + int sm, bool supports_finalize_fusion); + static std::vector getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion); static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsTmaWarpSpecialized() const; + + [[nodiscard]] bool supportsTmaWarpSpecialized() const { return supportsTmaWarpSpecialized(sm_); } + + [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm); [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const; From d47c865282bd16f25885b1821a834085020bac1d Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Sat, 11 Oct 2025 02:07:12 +0000 Subject: [PATCH 06/17] cutlass_fused_moe_kernels.cuh is troublesome... 
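
This patch threads the new runtime swap_ab flag through the stride and
problem-shape setup in cutlass_fused_moe_kernels.cuh. As a rough
illustration of the convention only (the stand-in struct, helper name and
dimensions below are illustrative, not the actual CUTLASS stride types used
by TmaWarpSpecializedGroupedGemmInput):

    // Sketch: when swap_ab is set the grouped GEMM computes
    // W * X^T = (X * W^T)^T, so the expert N dimension becomes the
    // GEMM M dimension and the token count becomes N.
    #include <cstdio>

    struct ProblemShapeMNK {
      int m, n, k;
    };

    ProblemShapeMNK make_problem_shape(bool swap_ab, int num_tokens, int gemm_n, int gemm_k) {
      return swap_ab ? ProblemShapeMNK{gemm_n, num_tokens, gemm_k}
                     : ProblemShapeMNK{num_tokens, gemm_n, gemm_k};
    }

    int main() {
      // Illustrative sizes only.
      ProblemShapeMNK s = make_problem_shape(/*swap_ab=*/true, /*num_tokens=*/4096,
                                             /*gemm_n=*/2048, /*gemm_k=*/7168);
      std::printf("MNK = %d %d %d\n", s.m, s.n, s.k);
      return 0;
    }
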
--- .../cutlass_fused_moe_instantiation.cu | 6 +- .../cutlass_fused_moe_kernels.cuh | 861 +++++++++--------- .../include/tensorrt_llm/common/cudaUtils.h | 3 + .../epilogue/fusion/sm90_visitor_scatter.hpp | 757 +++++++++++++++ .../fpA_intB_gemm/fpA_intB_gemm_template.h | 49 +- .../fpA_intB_gemm_template_sm90.h | 23 +- .../launchers/fpA_intB_launcher_sm90.h | 4 +- .../launchers/fpA_intB_launcher_sm90.inl | 23 +- .../launchers/fused_moe_gemm_launcher_sm80.h | 22 +- .../fused_moe_gemm_launcher_sm80.inl | 49 +- .../launchers/moe_gemm_tma_ws_launcher.h | 17 +- .../launchers/moe_gemm_tma_ws_launcher.inl | 2 +- .../moe_gemm_tma_ws_mixed_input_launcher.h | 33 +- .../moe_gemm_tma_ws_mixed_input_launcher.inl | 76 +- .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 22 +- .../moe_gemm/moe_gemm_kernels_bf16_fp4.cu | 4 +- .../moe_gemm/moe_gemm_kernels_bf16_fp8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 22 +- .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp16_fp4.cu | 6 +- .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp4_fp4.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp8_fp4.cu | 2 +- .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu | 22 +- .../moe_gemm/moe_gemm_kernels_fp8_uint4.cu | 22 +- .../moe_gemm/moe_gemm_template_dispatch.h | 236 +++-- .../moe_gemm_template_dispatch_tma_ws.h | 222 +++-- ...emm_template_dispatch_tma_ws_mixed_dtype.h | 20 +- .../moe_gemm_tma_warp_specialized_input.cu | 123 +-- .../moe_tma_warp_specialized_traits.h | 25 +- 33 files changed, 1881 insertions(+), 924 deletions(-) create mode 100644 csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index f20729f163..50dfcf78b9 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -17,8 +17,6 @@ #include "cutlass_fused_moe_kernels.cuh" #include "moe_kernels.h" -namespace tensorrt_llm::kernels::cutlass_kernels { -// ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; #ifdef ENABLE_BF16 @@ -38,6 +36,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>; #endif #endif #ifdef ENABLE_FP4 @@ -54,4 +53,5 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, _ template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>; #endif #endif -}; // namespace tensorrt_llm::kernels::cutlass_kernels +} +; // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 20c1ca3fd6..162e38bc65 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ 
b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -284,6 +284,7 @@ void buildMinLatencyActiveExpertMaps( num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, smart_routing, cluster_rank, cluster_size, num_experts_smem); } + template __global__ void fusedBuildExpertMapsSortFirstTokenKernel( int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, @@ -983,7 +984,7 @@ __device__ auto quantizePackedFPXValue( cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED_128x4); + QuantizationSFLayout::SWIZZLED); // Do the conversion and set the output and scaling factor auto func = [&]() { @@ -1007,7 +1008,8 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, int64_t elem_idx, int64_t num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, + bool const swizzled_input_sf = true) { static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; // We need to offset into the scaling factors for just this expert @@ -1024,15 +1026,28 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED_128x4); + QuantizationSFLayout::SWIZZLED); if (sf_out) { if (input_sf) { - auto const sf_in = cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols / VecSize, const_cast(input_sf), - QuantizationSFLayout::SWIZZLED_128x4); - *sf_out = *sf_in; + if (swizzled_input_sf) { + auto const sf_in = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, + const_cast(input_sf), + QuantizationSFLayout::SWIZZLED); + *sf_out = *sf_in; + } else { + auto const sf_in = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, + const_cast(input_sf), + QuantizationSFLayout::LINEAR); + *sf_out = *sf_in; + } } else { *sf_out = 0x00; } @@ -1075,18 +1090,25 @@ __device__ void setupFP4BlockScalingFactors( TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* weight_block_scale, int64_t num_tokens_before_expert) { - assert(layout_info.fpX_block_scaling_factors_stride_A); - assert(layout_info.fpX_block_scaling_factors_stride_B); - - // M & N swapped for transpose - auto stride_a_ptr = reinterpret_cast( - layout_info.fpX_block_scaling_factors_stride_A); - auto stride_b_ptr = reinterpret_cast( - layout_info.fpX_block_scaling_factors_stride_B); - stride_a_ptr[expert] = BSConfig::tile_atom_to_shape_SFB( - cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); - stride_b_ptr[expert] = 
BSConfig::tile_atom_to_shape_SFA( - cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); + assert(layout_info.fpX_block_scaling_factors_stride_act); + assert(layout_info.fpX_block_scaling_factors_stride_weight); + + auto stride_act_ptr = reinterpret_cast( + layout_info.fpX_block_scaling_factors_stride_act); + auto stride_weight_ptr = reinterpret_cast( + layout_info.fpX_block_scaling_factors_stride_weight); + if (layout_info.swap_ab) { + // M & N swapped for transpose + stride_act_ptr[expert] = BSConfig::tile_atom_to_shape_SFB( + cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); + stride_weight_ptr[expert] = BSConfig::tile_atom_to_shape_SFA( + cute::make_shape((int)gemm_n, (int)gemm_m, (int)gemm_k, (int)1)); + } else { + stride_act_ptr[expert] = BSConfig::tile_atom_to_shape_SFA( + cute::make_shape((int)gemm_m, (int)gemm_n, (int)gemm_k, (int)1)); + stride_weight_ptr[expert] = BSConfig::tile_atom_to_shape_SFB( + cute::make_shape((int)gemm_m, (int)gemm_n, (int)gemm_k, (int)1)); + } // This assert validates our current assumption that A&B can be safely transposed without needing // to modify @@ -1099,30 +1121,51 @@ __device__ void setupFP4BlockScalingFactors( std::is_same_v ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - layout_info.fpX_block_scaling_factors_A[expert] = + layout_info.fpX_block_scaling_factors_act[expert] = fp4_act_flat + getOffsetActivationSF(expert, num_tokens_before_expert, gemm_k, scaling_type); - layout_info.fpX_block_scaling_factors_B[expert] = + layout_info.fpX_block_scaling_factors_weight[expert] = weight_block_scale + getOffsetWeightSF(expert, gemm_n, gemm_k, scaling_type); } __device__ void computeTmaWarpSpecializedInputStrides( TmaWarpSpecializedGroupedGemmInput& layout_info, int gemm_m, int gemm_n, int gemm_k, int64_t out_idx) { - layout_info.stride_a[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::StrideA{}, cute::make_shape(gemm_m, gemm_k, 1)); - layout_info.stride_b[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::StrideB{}, cute::make_shape(gemm_n, gemm_k, 1)); + if (layout_info.swap_ab) { + reinterpret_cast( + layout_info.stride_act)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideB{}, + cute::make_shape(gemm_m, gemm_k, 1)); + reinterpret_cast( + layout_info.stride_weight)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideA{}, + cute::make_shape(gemm_n, gemm_k, 1)); + } else { + reinterpret_cast( + layout_info.stride_act)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideA{}, + cute::make_shape(gemm_m, gemm_k, 1)); + reinterpret_cast( + layout_info.stride_weight)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideB{}, + cute::make_shape(gemm_n, gemm_k, 1)); + } if (layout_info.stride_c) { + // TODO Enable 1xN bias matrix as C assert(false && "CUTLASS does not support a 1xN bias"); - // layout_info.stride_c[out_idx] = cute::make_stride(0, cute::Int<1>{}, 0); - layout_info.stride_c[out_idx] = cutlass::make_cute_packed_stride( - TmaWarpSpecializedGroupedGemmInput::StrideC{}, cute::make_shape(1, gemm_n, 1)); } if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info.default_epilogue.stride_d[out_idx] = cutlass::make_cute_packed_stride( - 
TmaWarpSpecializedGroupedGemmInput::DefaultEpilogue::StrideD{}, - cute::make_shape(gemm_n, gemm_m, 1)); + if (layout_info.swap_ab) { + reinterpret_cast( + layout_info.stride_d)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideD_T{}, + cute::make_shape(gemm_n, gemm_m, 1)); + } else { + reinterpret_cast( + layout_info.stride_d)[out_idx] = + cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::StrideD{}, + cute::make_shape(gemm_m, gemm_n, 1)); + } } if (layout_info.int4_groupwise_params.enabled) { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride( @@ -1142,18 +1185,27 @@ __device__ void computeTmaWarpSpecializedInputPointers( TmaWarpSpecializedGroupedGemmInput& layout_info, int64_t gemm_m, int64_t gemm_n, int64_t gemm_k, int num_tokens_before_expert, int64_t expert, T const* in, WeightType const* weights, TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale, - ScaleBiasType const* bias, OutputType* output, int64_t const out_idx) { + ScaleBiasType const* bias, OutputType* output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, int64_t const out_idx) { // The input prior to this contains K elements per token, with `num_tokens_before_expert` tokens - layout_info.ptr_a[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); + layout_info.ptr_act[out_idx] = safe_inc_ptr(in, num_tokens_before_expert * gemm_k); // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info.ptr_b[out_idx] = safe_inc_ptr(weights, expert * (gemm_n * gemm_k)); + layout_info.ptr_weight[out_idx] = safe_inc_ptr(weights, expert * (gemm_n * gemm_k)); if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { // The output prior to this contains N elements per token, with `num_tokens_before_expert` // tokens - layout_info.default_epilogue.ptr_d[out_idx] = - safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + layout_info.ptr_d[out_idx] = safe_inc_ptr(output, num_tokens_before_expert * gemm_n); + } + if (layout_info.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) { + layout_info.fused_finalize_epilogue.ptr_source_token_index[expert] = + permuted_row_to_unpermuted_row + num_tokens_before_expert; + layout_info.fused_finalize_epilogue.ptr_router_scales[expert] = + router_scales + num_tokens_before_expert; + if (layout_info.fused_finalize_epilogue.ptr_bias != nullptr) { + layout_info.fused_finalize_epilogue.ptr_bias[expert] = bias + gemm_n * expert; + } } if (layout_info.int4_groupwise_params.enabled) { // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 @@ -1180,7 +1232,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel( TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* gemm1_output, - OutputType* gemm2_output) { + OutputType* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row) { // First, compute the global tid. We only need 1 thread per expert. 
int const expert = blockIdx.x * blockDim.x + threadIdx.x; if (expert >= num_experts_per_node) { @@ -1199,22 +1252,26 @@ __global__ void computeStridesTmaWarpSpecializedKernel( // M and N transposed since we are using the #tokens as the N dimension layout_info1.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, - gemm1_k); + TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape( + layout_info1.swap_ab ? gemm1_n : gemm_m, layout_info1.swap_ab ? gemm_m : gemm1_n, + gemm1_k); layout_info2.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, - gemm2_k); + TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape( + layout_info2.swap_ab ? gemm2_n : gemm_m, layout_info2.swap_ab ? gemm_m : gemm2_n, + gemm2_k); if (layout_info1.int4_groupwise_params.enabled) { layout_info1.int4_groupwise_params.shape.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::ProblemShapeInt:: - UnderlyingProblemShape(gemm1_n, gemm_m, gemm1_k); + UnderlyingProblemShape(layout_info1.swap_ab ? gemm1_n : gemm_m, + layout_info1.swap_ab ? gemm_m : gemm1_n, gemm1_k); } if (layout_info2.int4_groupwise_params.enabled) { layout_info2.int4_groupwise_params.shape.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::ProblemShapeInt:: - UnderlyingProblemShape(gemm2_n, gemm_m, gemm2_k); + UnderlyingProblemShape(layout_info2.swap_ab ? gemm2_n : gemm_m, + layout_info2.swap_ab ? gemm_m : gemm2_n, gemm2_k); } if (alpha_scale_flat1 && alpha_scale_flat2) { @@ -1256,142 +1313,12 @@ __global__ void computeStridesTmaWarpSpecializedKernel( layout_info1, gemm_m, gemm1_n, gemm1_k, num_tokens_before_expert, expert, gemm1_in, weights1, reinterpret_cast( quant_params.groupwise.fc1.weight_scales), - bias1, gemm1_output, expert); + bias1, gemm1_output, nullptr, nullptr, expert); computeTmaWarpSpecializedInputPointers( layout_info2, gemm_m, gemm2_n, gemm2_k, num_tokens_before_expert, expert, gemm2_in, weights2, reinterpret_cast( quant_params.groupwise.fc2.weight_scales), - bias2, gemm2_output, expert); -} - -template -__global__ void computeStridesTmaWarpSpecializedLowLatencyKernel( - TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, - int64_t gemm1_k, int64_t gemm2_n, int64_t gemm2_k, int64_t const num_experts_per_node, - T const* in1, T const* in2, WeightType const* weights1, WeightType const* weights2, - float const* alpha_scale_flat1, float const* alpha_scale_flat2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* output1, - OutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert) { - // First, compute the global tid. We only need 1 thread per expert. 
- int const expert = blockIdx.x * blockDim.x + threadIdx.x; - - if (expert >= num_experts_per_node) { - return; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - - // Note: expert is used to calculate the offset of the input and output - // local_expert is used to calculate the offset of the weight - auto const num_tokens_before_expert = expert * num_tokens; - bool const is_active_expert = expert < *num_active_experts_per; - int const local_expert = is_active_expert ? active_expert_global_ids[expert] - start_expert : -1; - auto const gemm_m = is_active_expert ? num_tokens : 0; - - // M and N transposed since we are using the #tokens as the N dimension - layout_info1.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, - gemm1_k); - layout_info2.shape_info.problem_shapes[expert] = - TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, - gemm2_k); - - if (alpha_scale_flat1) { - assert(alpha_scale_flat2); - if (is_active_expert) { - layout_info1.alpha_scale_ptr_array[expert] = alpha_scale_flat1 + local_expert; - layout_info2.alpha_scale_ptr_array[expert] = alpha_scale_flat2 + local_expert; - } else { - layout_info1.alpha_scale_ptr_array[expert] = nullptr; - layout_info2.alpha_scale_ptr_array[expert] = nullptr; - } - } - - if (quant_params.fp4.fc1.weight_block_scale) { - setupFP4BlockScalingFactors( - layout_info1, expert, gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, - quant_params.fp4.fc1.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc1 uses the same A input for all experts and the scaling - // factor B offsets from the local expert index - if (is_active_expert) { - layout_info1.fpX_block_scaling_factors_A[expert] = fp4_act_flat1; - layout_info1.fpX_block_scaling_factors_B[expert] = - quant_params.fp4.fc1.weight_block_scale + - getOffsetWeightSF(local_expert, gemm1_n, gemm1_k, - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info1.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info1.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - - if (quant_params.fp4.fc2.weight_block_scale) { - setupFP4BlockScalingFactors( - layout_info2, expert, gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, - quant_params.fp4.fc2.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc2 scaling factor B offsets by the local expert index - if (is_active_expert) { - layout_info2.fpX_block_scaling_factors_B[expert] = - quant_params.fp4.fc2.weight_block_scale + - getOffsetWeightSF(local_expert, gemm2_n, gemm2_k, - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info2.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info2.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif - - assert(gemm_m <= INT32_MAX); - assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); - assert(gemm1_k > 0 && gemm1_k <= INT32_MAX); - assert(gemm2_n > 0 && gemm2_n <= INT32_MAX); - assert(gemm2_k > 0 && gemm2_k <= INT32_MAX); - computeTmaWarpSpecializedInputStrides(layout_info1, gemm_m, gemm1_n, gemm1_k, expert); - computeTmaWarpSpecializedInputStrides(layout_info2, gemm_m, gemm2_n, gemm2_k, expert); - - if (is_active_expert) { - // Note: under low latency mode, we use the same input for all experts - // so for 
gemm1, the inputs are the same, - // for gemm2, we use the input generated by gemm1 - layout_info1.ptr_a[expert] = in1; - layout_info2.ptr_a[expert] = safe_inc_ptr(in2, expert * num_tokens * gemm2_k); - - // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info1.ptr_b[expert] = safe_inc_ptr(weights1, local_expert * (gemm1_n * gemm2_k)); - layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); - - assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = - safe_inc_ptr(output1, expert * num_tokens * gemm1_n); - - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = - safe_inc_ptr(output2, expert * num_tokens * gemm2_n); - } - } else { - layout_info1.ptr_a[expert] = nullptr; - layout_info2.ptr_a[expert] = nullptr; - layout_info1.ptr_b[expert] = nullptr; - layout_info2.ptr_b[expert] = nullptr; - - layout_info1.default_epilogue.ptr_d[expert] = nullptr; - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; - } - } + bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert); } // ========================== Permutation things ======================================= @@ -1426,19 +1353,18 @@ __global__ void expandInputRowsKernel( int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, "AWQ and Block Scaling are mutually exclusive"); +#ifdef ENABLE_FP4 constexpr bool is_mxfp8 = std::is_same_v && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX && !PRE_QUANT_AWQ; constexpr bool is_mxfp8_input = is_mxfp8 && std::is_same_v; constexpr bool need_mxfp8_quant = is_mxfp8 && !is_mxfp8_input; - -#ifdef ENABLE_FP4 constexpr bool is_nvfp4 = std::is_same_v && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 && @@ -1446,6 +1372,9 @@ __global__ void expandInputRowsKernel( constexpr bool is_nvfp4_input = is_nvfp4 && std::is_same_v; constexpr bool need_nvfp4_quant = is_nvfp4 && !is_nvfp4_input; #else + constexpr bool is_mxfp8 = false; + constexpr bool is_mxfp8_input = false; + constexpr bool need_mxfp8_quant = false; constexpr bool is_nvfp4 = false; constexpr bool is_nvfp4_input = false; constexpr bool need_nvfp4_quant = false; @@ -1536,7 +1465,7 @@ __global__ void expandInputRowsKernel( "Cannot use per-expert act scale for pre-quantized activations"); writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, - fc1_act_sf_flat, input_sf); + fc1_act_sf_flat, input_sf, swizzled_input_sf); dest_row_ptr[elem_index] = in_vec; } } @@ -1632,8 +1561,8 @@ void expandInputRowsKernelLauncher( int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, 
int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, - bool enable_pdl, cudaStream_t stream) { + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, bool enable_pdl, cudaStream_t stream) { #ifdef ENABLE_FP4 TLLM_CHECK_WITH_INFO( (std::is_same_v && fc1_act_sf_flat) || @@ -1672,7 +1601,7 @@ void expandInputRowsKernelLauncher( // Could be either regular FP8 or MXFP8 else if constexpr (std::is_same_v && std::is_same_v) { - TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); + TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ"); return quant_params.mxfp8_mxfp4.fc1.weight_block_scale ? &expandInputRowsKernel< InputActivationsType, ExpandedActivationsType, @@ -1714,21 +1643,22 @@ void expandInputRowsKernelLauncher( cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale, use_per_expert_act_scale, - expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node, + expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf, + num_experts_per_node, reinterpret_cast(prequant_scales)); } -#define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ - template void expandInputRowsKernelLauncher( \ - InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \ - float const* unpermuted_scales, float* permuted_scales, \ - int const* permuted_row_to_unpermuted_row, int64_t const num_rows, \ - int64_t const hidden_size, int const k, int const num_experts_per_node, \ - QuantParams const& quant_params, bool use_per_expert_act_scale, \ - int64_t* expert_first_token_offset, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \ - bool enable_pdl, cudaStream_t stream) +#define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ + template void expandInputRowsKernelLauncher( \ + InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \ + float const* unpermuted_scales, float* permuted_scales, \ + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, \ + int64_t const hidden_size, int const k, int const num_experts_per_node, \ + QuantParams const& quant_params, bool use_per_expert_act_scale, \ + int64_t* expert_first_token_offset, \ + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, \ + void const* prequant_scales, bool enable_pdl, cudaStream_t stream) // Instantiate the data types that are used by the external pytorch op // INSTANTIATE_EXPAND_INPUT_ROWS(float, float); @@ -1751,22 +1681,24 @@ template ::value, sizeof_bits::value); - assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0); - int64_t const start_offset = threadIdx.x; int64_t const stride = FINALIZE_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD; using BiasElem = cutlass::Array; using 
InputElem = cutlass::Array; @@ -1781,7 +1713,7 @@ __global__ void finalizeMoeRoutingKernel( #endif #pragma unroll - for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { + for (int elem_index = start_offset; elem_index < num_elems_in_orig_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { @@ -1794,20 +1726,15 @@ __global__ void finalizeMoeRoutingKernel( int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - int64_t expanded_rows = num_rows * experts_per_token; - if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { - continue; - } - float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; auto const* expanded_permuted_rows_row_ptr = - expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; + expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_padded_col; ComputeElem expert_result = arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; + auto const* bias_ptr = bias_v + expert_id * num_elems_in_padded_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); } @@ -1830,8 +1757,13 @@ __global__ void finalizeMoeRoutingNoFillingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const orig_cols, - int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) { + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const padded_cols, + int64_t const unpadded_cols, int64_t const experts_per_token, int const num_experts_per_node, + int const start_expert_id) { + assert(padded_cols % 4 == 0); + assert(unpadded_cols % 4 == 0); + assert(unpadded_cols <= padded_cols); + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif @@ -1860,17 +1792,16 @@ __global__ void finalizeMoeRoutingNoFillingKernel( continue; } - OutputType* reduced_row_ptr = reduced_unpermuted_output + source_row * orig_cols; + OutputType* reduced_row_ptr = reduced_unpermuted_output + source_row * unpadded_cols; // Load 128-bits per thread, according to the smallest data type we read/write constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / std::min(sizeof_bits::value, sizeof_bits::value); - assert(orig_cols % FINALIZE_ELEM_PER_THREAD == 0); - int64_t const start_offset = threadIdx.x; int64_t const stride = FINALIZE_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_padded_col = padded_cols / FINALIZE_ELEM_PER_THREAD; + int64_t const num_elems_in_orig_col = unpadded_cols / FINALIZE_ELEM_PER_THREAD; using BiasElem = cutlass::Array; using InputElem = cutlass::Array; @@ -1881,7 +1812,10 @@ __global__ void finalizeMoeRoutingNoFillingKernel( reinterpret_cast(expanded_permuted_rows); auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); - for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { + for (int elem_index = start_offset; elem_index < 
num_elems_in_padded_col; + elem_index += stride) { + if (elem_index >= num_elems_in_orig_col) continue; // Skip writing beyond original columns + ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { @@ -1893,22 +1827,17 @@ __global__ void finalizeMoeRoutingNoFillingKernel( int64_t const expanded_permuted_row_from_k_idx = unpermuted_row_to_permuted_row[source_row + k_idx * num_rows]; - int64_t valid_tokens = expert_first_token_offset[num_experts_per_node]; - if (expanded_permuted_row_from_k_idx < 0 || - expanded_permuted_row_from_k_idx >= valid_tokens) { - continue; - } float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; auto const* expanded_permuted_rows_row_ptr = - expanded_permuted_rows_v + expanded_permuted_row_from_k_idx * num_elems_in_col; + expanded_permuted_rows_v + expanded_permuted_row_from_k_idx * num_elems_in_padded_col; ComputeElem expert_result = arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; + auto const* bias_ptr = bias_v + expert_id * num_elems_in_padded_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); } @@ -1928,10 +1857,10 @@ void finalizeMoeRoutingKernelLauncher( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales, int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, - int64_t const experts_per_token, int64_t const num_experts_per_node, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, - cudaStream_t stream) { + int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const padded_cols, + int64_t const unpadded_cols, int64_t const experts_per_token, + int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool enable_pdl, cudaStream_t stream) { // Only add bias on rank 0 for tensor parallelism bool const is_rank_0 = parallelism_config.tp_rank == 0; ScaleBiasType const* bias_ptr = is_rank_0 ? bias : nullptr; @@ -1962,8 +1891,8 @@ void finalizeMoeRoutingKernelLauncher( ScaleMode::NO_SCALE>; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, cols, - experts_per_token, num_experts_per_node, start_expert_id); + token_selected_experts, expert_first_token_offset, num_rows, padded_cols, + unpadded_cols, experts_per_token, num_experts_per_node, start_expert_id); } else { // If all-gather reduce-scatter is used, finalizeMoeRouting must fill invalid output tokens with // zeros. 
@@ -1976,20 +1905,21 @@ void finalizeMoeRoutingKernelLauncher( : &finalizeMoeRoutingKernel; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, - final_scales, unpermuted_row_to_permuted_row, token_selected_experts, cols, - experts_per_token, num_experts_per_node, start_expert_id); + final_scales, unpermuted_row_to_permuted_row, token_selected_experts, + padded_cols, unpadded_cols, experts_per_token, num_experts_per_node, + start_expert_id); } } -#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \ - template void finalizeMoeRoutingKernelLauncher( \ - GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, \ - ScaleBiasT const* bias, float const* final_scales, \ - int const* expanded_source_row_to_expanded_dest_row, \ - int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \ - int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \ - int64_t const experts_per_token, int64_t const num_experts_per_node, \ - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, \ +#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \ + template void finalizeMoeRoutingKernelLauncher( \ + GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, \ + ScaleBiasT const* bias, float const* final_scales, \ + int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, \ + int const* expert_for_source_row, int64_t const* expert_first_token_offset, \ + int64_t const num_rows, int64_t const padded_cols, int64_t const actual_cols, \ + int64_t const experts_per_token, int64_t const num_experts_per_node, \ + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, \ cudaStream_t stream); // // Instantiate the data types that are used by the external pytorch op @@ -2172,7 +2102,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x) { size_t gemm_result_offset = token * inter_size * gated_size_mul; size_t output_offset = token * inter_size; @@ -2188,6 +2117,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; + gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; gate_limit = activation_params.swiglu_limit ? activation_params.swiglu_limit[expert] @@ -2245,7 +2175,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, linear_value + arrayConvert(bias_ptr_vec[elem_index]); } return fn(fc1_value, linear_value); - } else { return fn(fc1_value); } @@ -2379,7 +2308,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 &doActivationKernel, decltype(block_scaling_type)::value> // Identity - }; return fn_list[static_cast(activation_type.activation_type)]; }; @@ -2831,10 +2759,11 @@ void CutlassMoeFCRunnerepilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + permuted_token_final_scales_ = + gemm2_using_finalize_fusion ? 
getWsPtr(float{}, "permuted_token_final_scales") : nullptr; bool const is_gated_activation = isGatedActivation(activation_type); bool const gemm1_using_fused_moe = moe_gemm_runner_.isFusedGatedActivation( @@ -2976,9 +2905,10 @@ void CutlassMoeFCRunner(use_ampere_activation_fusion ? output : intermediate_result), alpha_scale_ptr_array, /*occupancy*/ nullptr, - use_ampere_activation_fusion ? fc1_activation_type : ActivationType::Identity, + use_ampere_activation_fusion ? fc1_activation_type.activation_type + : ActivationType::Identity, expanded_num_rows, /*N*/ int64_t(fc1_out_size), /*K*/ hidden_size, @@ -3268,9 +3199,9 @@ void CutlassMoeFCRunner( static_cast(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, + unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, + enable_pdl, stream); } else if (!using_tma_ws_gemm2) { finalizeMoeRoutingKernelLauncher( static_cast(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, - token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, + unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, + enable_pdl, stream); } sync_check_cuda_error(stream); } @@ -3597,16 +3530,16 @@ void CutlassMoeFCRunner void CutlassMoeFCRunner::runMoe( - void const* input_activations_void, void const* input_sf_void, + void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, - int64_t const hidden_size, int64_t const inter_size, int const full_num_experts, - int const experts_per_token, char* workspace_ptr, void* final_output_void, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, bool use_lora, LoraParams& lora_params, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, + int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, + int const full_num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output_void, int* unpermuted_row_to_permuted_row, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, + LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) { static constexpr bool int_scales_required = std::is_same::value || std::is_same::value || @@ -3661,6 +3594,27 @@ void CutlassMoeFCRunner::value)); } else { + // For NoSmem epilogue schedule, we need to align the output of the GEMM to 256 bits, for gated + // activation this is automatic if the usual alignment requirement is met + if (gemm1_config_->epilogue_schedule == 
cutlass_extensions::EpilogueScheduleType::NO_SMEM && + !isGatedActivation(fc1_activation_type)) { + TLLM_CHECK_WITH_INFO( + inter_size % (256 / sizeof_bits::value) == 0, + "Inter size %d does not meet minimum alignment requirements for MOE GEMM %d", + (int)inter_size, (int)(256 / sizeof_bits::value)); + } + + if (gemm2_config_->epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM) { + TLLM_CHECK_WITH_INFO( + gemm2_config_->epilogue_fusion_type != + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE, + "Got NoSmem epilogue schedule, which is not supported for finalize fusion"); + TLLM_CHECK_WITH_INFO( + hidden_size % (256 / sizeof_bits::value) == 0, + "Hidden size %d does not meet minimum alignment requirements for MOE GEMM %d", + (int)hidden_size, (int)(256 / sizeof_bits::value)); + } + // Require at least 128 bits of alignment for MOE GEMM TLLM_CHECK_WITH_INFO( hidden_size % (128 / sizeof_bits::value) == 0, @@ -3752,10 +3706,11 @@ void CutlassMoeFCRunner:: TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, bool enable_pdl, + float const* router_scales, int const* permuted_row_to_unpermuted_row, cudaStream_t stream) { // Always nullptr layout_info1.ptr_c = nullptr; @@ -3920,6 +3879,11 @@ CutlassMoeFCRunner:: layout_info2.ptr_c = nullptr; layout_info2.stride_c = nullptr; + layout_info1.fused_finalize_epilogue.ptr_bias = nullptr; + if (!bias2) { + layout_info2.fused_finalize_epilogue.ptr_bias = nullptr; + } + auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale : use_fp8 ? fp8_dequant1 @@ -3961,7 +3925,8 @@ CutlassMoeFCRunner:: layout_info2, num_tokens, expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node, gemm1_in, gemm2_in, weights1, weights2, alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, - quant_params, bias1, bias2, gemm1_output, gemm2_output); + quant_params, bias1, bias2, gemm1_output, gemm2_output, router_scales, + permuted_row_to_unpermuted_row); return std::make_pair(layout_info1, layout_info2); } @@ -3982,55 +3947,7 @@ CutlassMoeFCRunner:: UnfusedGemmOutputType* output1, UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, bool enable_pdl, cudaStream_t stream) { - TLLM_CHECK_WITH_INFO(!use_w4_groupwise, - "W4AFP8 and WFP4A16 are not supported in low latency mode"); - - // Always nullptr - layout_info1.ptr_c = nullptr; - layout_info1.stride_c = nullptr; - layout_info2.ptr_c = nullptr; - layout_info2.stride_c = nullptr; - - auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale - : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale - : fp8_dequant1; - auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale - : use_wfp4afp8 ? 
quant_params.fp8_mxfp4.fc2.global_scale - : fp8_dequant2; - if (!alpha_scale_flat1) { - layout_info1.alpha_scale_ptr_array = nullptr; - } - if (!alpha_scale_flat2) { - layout_info2.alpha_scale_ptr_array = nullptr; - } - - layout_info1.int4_groupwise_params.enabled = false; - layout_info2.int4_groupwise_params.enabled = false; - layout_info1.int4_groupwise_params.use_wfp4a16 = false; - layout_info2.int4_groupwise_params.use_wfp4a16 = false; - - int const threads = std::min(1024, num_experts); - int const blocks = (num_experts + threads - 1) / threads; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx( - &config, - computeStridesTmaWarpSpecializedLowLatencyKernel, - layout_info1, layout_info2, num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts, - input1, input2, weights1, weights2, alpha_scale_flat1, alpha_scale_flat2, fc1_fp4_act_flat, - fc2_fp4_act_flat, quant_params, bias1, bias2, output1, output2, num_active_experts_per, - active_expert_global_ids, start_expert); - - return std::make_pair(layout_info1, layout_info2); + TLLM_THROW("Min latency mode is no longer supported"); } template :: setupTmaWarpSpecializedInputs(int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, - int64_t inter_size, int64_t num_experts_per_node, - void const* input_activations_void, + int64_t unpadded_hidden_size, int64_t inter_size, + int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -4078,6 +3995,8 @@ CutlassMoeFCRunner:: gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + gemm1_tma_ws_input.swap_ab = true; + gemm2_tma_ws_input.swap_ab = true; TLLM_CHECK_WITH_INFO(gemm1_input != gemm1_output, "Input and output buffers are overlapping"); return Self::computeStridesTmaWarpSpecializedLowLatency( @@ -4095,17 +4014,28 @@ CutlassMoeFCRunner:: gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + gemm1_tma_ws_input.swap_ab = gemm1_config_->swap_ab; + gemm2_tma_ws_input.swap_ab = gemm2_config_->swap_ab; + TLLM_CHECK_WITH_INFO( + (gemm1_tma_ws_input.swap_ab && gemm2_tma_ws_input.swap_ab) || !use_w4_groupwise, + "Hopper w4 mixed input groupwise requires swap_ab"); + bool apply_bias = parallelism_config.tp_rank == 0; - bool using_hopper_fused_finalize = !use_deterministic_hopper_reduce_ && - gemm2_config_->sm_version == 90 && !use_w4_groupwise && - !use_lora; - if (using_hopper_fused_finalize) { + auto* fc2_bias = apply_bias ? 
fc2_expert_biases : nullptr; + bool gemm2_using_finalize_fusion = + gemm2_config_->epilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + bool using_fused_finalize = + use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora; + TLLM_CHECK_WITH_INFO( + using_fused_finalize == gemm2_using_finalize_fusion, + "GEMM2 tactic requests finalize fusion, but the runner is not configured to use it"); + if (using_fused_finalize) { assert(min_latency_mode == false); + bool use_reduction = expanded_num_rows > num_rows; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams( - final_output, permuted_token_final_scales_, expert_first_token_offset_, - permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr, hidden_size, - num_rows); + gemm2_tma_ws_input.setFinalizeFusionParams(final_output, unpadded_hidden_size, num_rows, + use_reduction); } // fp8_mxfp4 memsets the scaling factors to 1.0f @@ -4118,13 +4048,9 @@ CutlassMoeFCRunner:: TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int{}; #ifdef ENABLE_FP8 -#if CUDA_VERSION >= 12080 __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(1.0f, __NV_SATFINITE, cudaRoundPosInf); std::memcpy(&weight_block_scale_value_int, &tmp, sizeof(tmp)); -#else - TLLM_CHECK_WITH_INFO(false, "WFP4AFP8 is not supported on CUDA "); -#endif #endif auto act_sf_rows = std::min(expanded_num_rows, num_rows * num_experts_per_node); @@ -4147,9 +4073,9 @@ CutlassMoeFCRunner:: reinterpret_cast(gemm1_input), reinterpret_cast(gemm2_input), fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, - fc1_expert_biases, fc2_expert_biases, - reinterpret_cast(gemm1_output), - reinterpret_cast(fc2_result_), enable_pdl, stream); + fc1_expert_biases, fc2_bias, reinterpret_cast(gemm1_output), + reinterpret_cast(fc2_result_), permuted_token_final_scales_, + permuted_row_to_unpermuted_row_, enable_pdl, stream); } } @@ -4409,7 +4335,7 @@ std::map> GemmProfilerBackend::getProfile if (is_tma_ws_input) { tma_ws_input_workspace_size = TmaWarpSpecializedGroupedGemmInput::workspaceSize(num_experts_per_node, mScalingType) * - (NUM_ROUTING_SAMPLES + 1); + (NUM_ROUTING_SAMPLES * NUM_FUSION_TYPES * NUM_SWAP_AB_TYPES + 1); if (is_w4afp8_quant || is_wfp4a16_quant) { quant_3_size = 0; @@ -4506,7 +4432,6 @@ std::map> GemmProfilerBackend::getProfile ADD(swiglu_alpha); ADD(swiglu_beta); ADD(swiglu_limit); - #undef ADD_NAME #undef ADD @@ -4637,13 +4562,32 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr } } -void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, - void const* expert_weights, bool enable_pdl, - cudaStream_t stream) { +void GemmProfilerBackend::prepareTmaWsInputs( + int num_tokens, char* workspace_ptr_char, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, bool enable_pdl, + cudaStream_t stream) { if (mSM < 90) { return; } + bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); + bool use_wfp4a16 = + ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && + mWType == nvinfer1::DataType::kUINT8); + bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + bool const use_finalize_fusion = + fusion == 
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; + bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || + use_w4_groupwise || + mGemmToProfile != GemmToProfile::GEMM_2; + if (use_finalize_fusion && finalize_fusion_not_supported) { + return; + } + + if (use_w4_groupwise && !swap_ab) { + return; + } + auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90); #define GET_WS_PTR(type, name) \ @@ -4681,11 +4625,19 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr dummy_tma_ws_input.enable_pdl = enable_pdl; // Set enable_pdl for dummy input tma_ws_input_workspace += tma_ws_size; + int workspace_index = + static_cast(use_finalize_fusion) * (NUM_SWAP_AB_TYPES * NUM_ROUTING_SAMPLES) + + static_cast(swap_ab) * NUM_ROUTING_SAMPLES; + tma_ws_input_workspace += workspace_index * tma_ws_size; + size_t num_expanded_tokens = num_tokens * mK; for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { - mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, - workspaces.at("gemm_workspace").first, mScalingType); - mTmaInputCache[i].enable_pdl = enable_pdl; // Set enable_pdl for the profiler + // Note: Even though we have separate TMA WS inputs for finalize fusion on/off we reuse the same + // pointers to save space. + auto& cache_element = mTmaInputCache[use_finalize_fusion][swap_ab][i]; + cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, + workspaces.at("gemm_workspace").first, mScalingType); + cache_element.enable_pdl = enable_pdl; // Set enable_pdl for cache element tma_ws_input_workspace += tma_ws_size; int64_t* expert_first_token_offset = @@ -4694,34 +4646,27 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr permuted_row_to_unpermuted_row_base + i * num_expanded_tokens; auto& gemm1_tma_ws_input = - mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input; + mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input; auto& gemm2_tma_ws_input = - mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input; + mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input; if (mSM >= 90) { + auto fc1_output_size = + isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize; + /* GEMM1 */ gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; - bool apply_bias = true; - bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); - bool use_wfp4a16 = - ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && - mWType == nvinfer1::DataType::kUINT8); - bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + gemm1_tma_ws_input.swap_ab = swap_ab; + gemm2_tma_ws_input.swap_ab = swap_ab; - bool using_fused_finalize = !mInterface->use_deterministic_hopper_reduce_ && mSM == 90 && - !mMinLatencyMode && !use_w4_groupwise; - if (using_fused_finalize) { + if (use_finalize_fusion) { assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; - gemm2_tma_ws_input.setFinalizeFusionParams( - output, token_topk_unpermuted_scales, expert_first_token_offset, - permuted_row_to_unpermuted_row, apply_bias ? 
bias : nullptr, mExpertHiddenSize, - num_tokens); + gemm2_tma_ws_input.setFinalizeFusionParams(output, mExpertUnpaddedHiddenSize, num_tokens, + mK > 1); } - auto fc1_output_size = - isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize; if (mMinLatencyMode) { std::tie(gemm1_tma_ws_input, gemm2_tma_ws_input) = mInterface->computeStridesTmaWarpSpecializedLowLatencyDispatch( @@ -4739,7 +4684,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr mExpertInterSize, mNumExpertsPerNode, input, input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate, - enable_pdl, stream); + token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, enable_pdl, stream); } sync_check_cuda_error(stream); } @@ -4749,7 +4694,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights, bool enable_pdl, cudaStream_t stream) { - mAllTacticsSaved = mInterface->getTactics(); mSampleIndex = 0; auto workspace_size = getWorkspaceSize(num_tokens); @@ -4757,7 +4701,13 @@ void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, prepareRouting(num_tokens, workspace_ptr_char, enable_pdl, stream); prepareQuantParams(num_tokens, workspace_ptr_char, stream); - prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, enable_pdl, stream); + for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE}) { + for (auto swap_ab : {false, true}) { + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, fusion, swap_ab, + enable_pdl, stream); + } + } } size_t GemmProfilerBackend::getWorkspaceSize(int maxM) { @@ -4772,7 +4722,7 @@ size_t GemmProfilerBackend::getWorkspaceSize(int maxM) { void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tactic, char* workspace_ptr_char, void const* expert_weights, - bool enable_pdl, cudaStream_t const& stream) { + cudaStream_t const& stream) { int64_t expanded_num_tokens = original_num_tokens * mK; int64_t num_experts_per_node = mNumExpertsPerNode; @@ -4824,54 +4774,85 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac TmaWarpSpecializedGroupedGemmInput tma_ws_input_template; if (tactic.is_tma_warp_specialized) { - tma_ws_input_template = mTmaInputCache[mSampleIndex]; + tma_ws_input_template = + mTmaInputCache[tactic.epilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE] + [tactic.swap_ab][mSampleIndex]; + TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), + "TMA WS input template is not initialized"); } mInterface->is_profiler = true; if (mGemmToProfile == GemmToProfile::GEMM_1) { - mInterface->gemm1( - input, // - output, // - intermediate, // - expert_first_token_offset, // - tma_ws_input_template, // - weights_sel, // - bias, // - expert_first_token_offset + num_experts_per_node, // - mQuantParams.wo.fc1_weight_scales, // - mQuantParams.fp8.dequant_fc1, // - mQuantParams.fp8_mxfp4.fc2.act_global_scale ? 
mQuantParams.fp8_mxfp4.fc2.act_global_scale - : mQuantParams.fp8.quant_fc2, // - fp4_act_scale_flat, // - fp4_act_scale_flat, // - mQuantParams, // - original_num_tokens, // - expanded_num_tokens, // - mExpertHiddenSize, // - mExpertInterSize, // - num_experts_per_node, // - ActivationParams(mActivationType, swiglu_alpha, swiglu_beta, swiglu_limit), // - alpha_scale_ptr_array, // - !mUseLora, // - /*use_deepseek_fp8_block_scale=*/false, // - stream, // - tactic, // - mMinLatencyMode, // - num_active_experts_per_node, // - active_expert_global_ids, // - enable_pdl); // + mInterface->gemm1(input, // + output, // + intermediate, // + expert_first_token_offset, // + tma_ws_input_template, // + weights_sel, // + bias, // + expert_first_token_offset + num_experts_per_node, // + mQuantParams.wo.fc1_weight_scales, // + mQuantParams.fp8.dequant_fc1, // + mQuantParams.fp8_mxfp4.fc2.act_global_scale + ? mQuantParams.fp8_mxfp4.fc2.act_global_scale + : mQuantParams.fp8.quant_fc2, // + fp4_act_scale_flat, // + fp4_act_scale_flat, // + mQuantParams, // + original_num_tokens, // + expanded_num_tokens, // + mExpertHiddenSize, // + mExpertInterSize, // + num_experts_per_node, // + ActivationParams(mActivationType, swiglu_alpha, swiglu_beta, swiglu_limit), + alpha_scale_ptr_array, // + !mUseLora, // + /*use_deepseek_fp8_block_scale=*/false, // + stream, // + tactic, // + mMinLatencyMode, // + num_active_experts_per_node, // + active_expert_global_ids, // + enable_pdl); // } else { TLLM_CHECK(mGemmToProfile == GemmToProfile::GEMM_2); - mInterface->gemm2( - input, intermediate, output, expert_first_token_offset, tma_ws_input_template, weights_sel, - bias, mQuantParams.wo.fc2_weight_scales, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, - mQuantParams, token_topk_unpermuted_scales, token_topk_permuted_scales, - unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, - expert_first_token_offset + mNumExpertsPerNode, original_num_tokens, expanded_num_tokens, - mExpertHiddenSize, mExpertInterSize, num_experts_per_node, mK, alpha_scale_ptr_array, false, - nullptr, - /*use_deepseek_fp8_block_scale=*/false, stream, mParallelismConfig, mEnableAlltoall, tactic, - mMinLatencyMode, num_active_experts_per_node, active_expert_global_ids, enable_pdl); + mInterface->gemm2(input, // + intermediate, // + output, // + expert_first_token_offset, // + tma_ws_input_template, // + weights_sel, // + bias, // + mQuantParams.wo.fc2_weight_scales, // + mQuantParams.fp8.dequant_fc2, // + fp4_act_scale_flat, // + mQuantParams, // + token_topk_unpermuted_scales, // + token_topk_permuted_scales, // + unpermuted_row_to_permuted_row, // + permuted_row_to_unpermuted_row, // + token_selected_experts, // + expert_first_token_offset + mNumExpertsPerNode, // + original_num_tokens, // + expanded_num_tokens, // + mExpertHiddenSize, // + mExpertUnpaddedHiddenSize, // + mExpertInterSize, // + num_experts_per_node, // + mK, // + alpha_scale_ptr_array, // + false, // + nullptr, // + /*use_deepseek_fp8_block_scale=*/false, // + stream, // + mParallelismConfig, // + mEnableAlltoall, // + tactic, // + mMinLatencyMode, // + num_active_experts_per_node, // + active_expert_global_ids, // + enable_pdl); // } mInterface->is_profiler = false; diff --git a/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h b/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h index ccddbc1ef5..5f757f1b51 100644 --- a/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h +++ 
b/csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h @@ -1181,6 +1181,9 @@ using Int = ConstExprWrapper; template using Bool = ConstExprWrapper; +template +using ConstBool = ConstExprWrapper; + template struct TmaDescType; diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp new file mode 100644 index 0000000000..c98f7ee3c1 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp @@ -0,0 +1,757 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +template < + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm90ScatterPtrArray { + + using SmemShape = decltype(make_shape(size(make_layout(get<0>(EpilogueTile{}))), size(make_layout(get<1>(EpilogueTile{}))))); + using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, SmemShape{})); + + using ElementIndex = int32_t; + + static constexpr bool MajorMode = cutlass::gemm::detail::is_major<0,StrideOutput>() ? 0 : 1; + + using StrideIndex = decltype(replace<1-MajorMode>(Stride<_0,_0,_0>{}, Int<1>{})); + + struct SharedStorage {}; + + struct Arguments { + ElementOutput* ptr_out{}; // output tensor pointer + StrideOutput dOut = {}; // output tensor stride + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; // use reduction or regular store + }; + + struct Params { + ElementOutput* ptr_out{}; // output tensor pointer + StrideOutput dOut = {}; // output tensor stride + ElementIndex const* const* ptr_index{}; // per-group pointer to the scatter index + cutlass::FastDivmod index_divmod{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; // use reduction or regular store + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return { + args.ptr_out, + args.dOut, + args.ptr_index, + cutlass::FastDivmod(args.index_modulo), + args.shape_override, + args.use_reduction + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScatterPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class ArgsTuple + > + struct ConsumerStoreCallbacks : 
EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple) + : args_tuple(std::move(args_tuple)) {} + + ArgsTuple args_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rOut_frg = recast>(coalesce(tC_rOut)); // (EPI_V) + tC_rOut_frg(epi_v) = convert_input(frg_input); + + return tC_rOut_frg(epi_v); + } + + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tC_rOut, tiled_r2s, tRG_gOut, tRG_cD, tiled_r2g_red, tiled_r2g_stg, use_reduction, thread_idx, residue_cD] = args_tuple; + + Tensor byte_buffer = recast(reduction_buffer); + static_assert(cosize(byte_buffer.layout()) * sizeof_bits_v >= cosize(SmemLayout{}) * sizeof_bits_v, + "Not enough space in scratch smem buffer"); + + Tensor sOut = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(recast_ptr(byte_buffer.data())), SmemLayout{})); + + auto thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sOut_epi = thread_r2s.partition_D(sOut); + Tensor tRS_rOut_epi = thread_r2s.retile_S(tC_rOut); + + auto thread_r2g = tiled_r2g_red.get_slice(thread_idx); + Tensor tRG_gOut_epi = tRG_gOut(_,_,_,epi_m,epi_n); + Tensor tRG_sOut_epi = thread_r2g.partition_D(sOut); + Tensor tRG_rOut_epi = thread_r2g.retile_S(make_tensor(tC_rOut.data(), shape(tRG_sOut_epi))); // reuse D registers + + // sanity check for register reuse + CUTE_STATIC_ASSERT_V(cosize(tC_rOut.layout()) == cosize(tRG_rOut_epi.layout()), "Invalid register count for R2G"); + + copy(tiled_r2s, tRS_rOut_epi, tRS_sOut_epi); + sync_fn(); + copy(tRG_sOut_epi, tRG_rOut_epi); + + auto residue = residue_cD; // capturing structured bindings is a C++20 feature + Tensor tRG_cD_epi = tRG_cD(0,_,_,epi_m,epi_n); + auto pred = cute::lazy::transform(tRG_cD_epi, [&](auto c){ return elem_less(c, residue); }); + + if (use_reduction) { + copy_if(tiled_r2g_red, pred, tRG_rOut_epi, tRG_gOut_epi); + } + else { + copy_if(tiled_r2g_stg, pred, tRG_rOut_epi, tRG_gOut_epi); + } + } + }; + + template + static constexpr auto get_reduction_op() + { + using namespace cute; + + // For now only support red.add + if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) { + if constexpr (MaxVecSize % 8 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize % 4 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize % 2 == 0) { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } + } + + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto index_read = [index = params_ptr->ptr_index[l], divmod = params_ptr->index_divmod](auto i){ return divmod.rem(index[i]); }; + Tensor mOut = cutlass::util::make_gather_tensor(params_ptr->ptr_out, make_shape(M,N,Int<1>{}), params_ptr->dOut, index_read); // (M,N,_1) + Tensor gOut = local_tile(mOut, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gOut_epi = flat_divide(gOut, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor mIdx = make_tensor(params_ptr->ptr_index[l], make_shape(M,N,Int<1>{}), StrideIndex{}); // (M,N,_1) + Tensor gIdx = local_tile(mIdx, take<0,2>(args.tile_shape_mnk), make_coord(m,n,Int<0>{})); // (CTA_M,CTA_N) + Tensor gIdx_epi = flat_divide(gIdx, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor cD_epi = flat_divide(args.cD, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor tC_gOut = sm90_partition_for_epilogue(gOut, args.epi_tile, args.tiled_copy, args.thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tC_rOut = make_tensor(take<0,3>(shape(tC_gOut))); // (CPY,CPY_M,CPY_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + + // Vectorization must not exceed alignment and also the number of values per thread in the tile + int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy)); + int constexpr NumValTile = product(take<0,2>(shape(cD_epi))); + int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads); + + // Choose the largest available red.global op and an st.global op with matching vectorization + using CopyOpR2GRed = decltype(get_reduction_op()); + using CopyOpR2GStg = UniversalCopy::NumValSrc * sizeof_bits_v>>; + + auto make_tiled_r2g = [&](auto copy_op) + { + using CopyAtomR2G = Copy_Atom; + constexpr int VecSize = CopyAtomR2G::NumValSrc; + if constexpr (cutlass::gemm::detail::is_k_major()) { + constexpr int ThreadsMajor = size<1>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) { + constexpr int ThreadsMajor = size<0>(args.epi_tile) / VecSize; + constexpr int ThreadsMinor = NumThreads / ThreadsMajor; + return make_tiled_copy(CopyAtomR2G{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }; + + auto tiled_r2g_red = make_tiled_r2g(CopyOpR2GRed{}); + auto tiled_r2g_stg = make_tiled_r2g(CopyOpR2GStg{}); + + // Sanity checks - since we will be using one tiled copy with tensors partitioned with the other tiled copy, + // ensure they have matching layouts/tilers + using TiledR2GRed = decltype(tiled_r2g_red); + using TiledR2GStg = decltype(tiled_r2g_stg); + static_assert(typename TiledR2GRed::AtomLayoutSrc{} == typename TiledR2GStg::AtomLayoutSrc{}, "Mismatching AtomLayoutSrc"); + static_assert(typename TiledR2GRed::AtomLayoutDst{} == typename TiledR2GStg::AtomLayoutDst{}, "Mismatching AtomLayoutDst"); + static_assert(typename TiledR2GRed::TiledLayout_TV{} == typename TiledR2GStg::TiledLayout_TV{}, "Mismatching TiledLayout_TV"); + static_assert(typename TiledR2GRed::Tiler_MN{} == typename 
TiledR2GStg::Tiler_MN{}, "Mismatching Tiler_MN"); + + auto thread_r2g = tiled_r2g_red.get_slice(args.thread_idx); + Tensor tRG_gOut = thread_r2g.partition_D(gOut_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + Tensor tRG_cD = thread_r2g.partition_D(cD_epi); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + + auto residue_cD = args.residue_cD; + + // If shape_override is set, adjust residue_cD to change predication. + // This is used to support fused slicing (where the output tensor is smaller than problem shape) + if (params_ptr->shape_override >= 0) { + get(residue_cD) += params_ptr->shape_override - get(args.problem_shape_mnkl); + } + + auto args_tuple = make_tuple( + cute::move(tC_rOut), + tiled_r2s, + tRG_gOut, + tRG_cD, + tiled_r2g_red, + tiled_r2g_stg, + params_ptr->use_reduction, + args.thread_idx, + residue_cD); + + return ConsumerStoreCallbacks(std::move(args_tuple)); + } +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBias + : ScaledAcc +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerColBias + : ScaledAcc +{ + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerRowBiasPerColScaleScatter + : ScaledAccPerRowBias +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +template< + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct ScaledAccPerColBiasPerRowScaleScatter + : ScaledAccPerColBias +{ + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOut; + static constexpr int AlignmentAux = AlignmentOutput; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = alpha * acc + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPtrArray = + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, 
ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerColBiasPtrArray = + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray = + Sm90EVT, // scatter store + Sm90EVT, // scale * (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_0,_1,int64_t>, 1>, // scale + Sm90ScaledAccPerRowBiasPtrArray // alpha * acc + bias + > + >; + +template< + class CtaTileShapeMNK, + class EpilogueTile, + class StrideOutput, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScale = ElementCompute, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / cute::sizeof_bits_v, + int AlignmentOutput = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray = + Sm90EVT, // scatter store + Sm90EVT, // scale * (alpha * acc + bias) + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_1,_0,int64_t>, 1>, // scale + Sm90ScaledAccPerColBiasPtrArray // alpha * acc + bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::ScaledAccPerRowBiasPerColScaleScatter, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + + using Impl = Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerRowBiasPerColScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + 
RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias dBias{}; + + using StrideScale = Stride<_0,_1,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out{}; // output tensor pointer + StrideOutput dOut{}; // output tensor stride + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; // use reduction or regular store + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class GmemLayoutTagOut, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScale, + class ElementScalar, + int AlignmentBias, + int AlignmentOutput, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::ScaledAccPerColBiasPerRowScaleScatter, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + > { + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + + using Impl = Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray< + CtaTileShapeMNK, + EpilogueTile, + StrideOutput, + SmemLayoutAtom, CopyOpR2S, + ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar, + AlignmentBias, AlignmentOutput, RoundStyle + >; + using Operation = fusion::ScaledAccPerColBiasPerRowScaleScatter< + GmemLayoutTagOut, + ElementOutput, + ElementCompute, + ElementBias, + ElementScale, + ElementScalar, + AlignmentBias, + AlignmentOutput, + RoundStyle>; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* const* bias_ptr{}; + StrideBias 
dBias{}; + + using StrideScale = Stride<_1,_0,int64_t>; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + // Nested args not usable due to a compiler bug with constexpr evaluation + // using ScatterArguments = typename Sm90ScatterPtrArray::Arguments; + // ScatterArguments scatter{}; + + ElementOutput* ptr_out = nullptr; + StrideOutput dOut = {}; + int const* const* ptr_index{}; // per-group pointer to the scatter index + int index_modulo{}; // modulo used to transform the index before store + int shape_override = -1; // override value for contiguous output tensor mode + bool use_reduction = true; + + operator typename Impl::Arguments() const { + return + { // unary op: reduce(scale * (beta * C + (alpha * acc))) + { // binary op: scale * (beta * C + (alpha * acc)) + { scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end binary op + {} // binary args: multiply + }, // end binary op + //scatter // unary args: reduce + { ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction } + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; + +}; + +} // namespace cutlass::epilogue::fusion + +// clang-format on diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 14ba601b39..2cc10e382b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -46,6 +46,7 @@ namespace tkc = tensorrt_llm::cutlass_extensions; namespace tensorrt_llm { namespace kernels { namespace cutlass_kernels { +using namespace cute; template 2 && arch::kMinComputeCapability < 80) { // Multistage only supported on Ampere std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][filter_and_run_mixed_gemm] " + err_msg); } else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89) { // Multistage only supported on Ampere std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][filter_and_run_mixed_gemm] " + err_msg); } else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { // FP8 activation type only supported on Ada+ GPUs std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][filter_and_run_mixed_gemm] " + err_msg); } else { generic_mixed_gemm_kernelLauncher() || is_fp8() || is_fp8() || is_fp8() || is_fp8(); @@ -362,17 +369,17 @@ void 
dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, break; case tkc::CutlassTileConfig::Undefined: throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); + "[TensorRT LLM Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); break; case tkc::CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have " + "[TensorRT LLM Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have " "already been set by " "heuristic."); break; default: throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed " + "[TensorRT LLM Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed " "type GEMM."); break; } @@ -380,7 +387,7 @@ void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, // This is not a limitation in CUTLASS. We just do not need to support this case. std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier."; - throw std::runtime_error("[TensorRT-LLm Error][dispatch_gemm_to_cutlass] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][dispatch_gemm_to_cutlass] " + err_msg); } } @@ -388,6 +395,7 @@ template CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); int device{-1}; tk::check_cuda_error(cudaGetDevice(&device)); sm_ = tk::getSMVersion(); @@ -398,7 +406,9 @@ CutlassFpAIntBGemmRunner CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() {} + OutputType>::~CutlassFpAIntBGemmRunner() { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} template @@ -414,6 +424,7 @@ void CutlassFpAIntBGemmRunner< tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); if (sm_ >= 75 && sm_ < 80) { dispatch_gemm_to_cutlass( @@ -429,7 +440,7 @@ void CutlassFpAIntBGemmRunner< ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) if constexpr (cutlass::platform::is_same::value) { throw std::runtime_error( - "[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] INT4xFP8 GEMM for Ada " + "[TensorRT LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] INT4xFP8 GEMM for Ada " "needs " "CUDA>=12.4"); } @@ -442,13 +453,13 @@ void CutlassFpAIntBGemmRunner< static_assert(!cutlass::platform::is_same::value || cutlass::platform::is_same::value, "ScaleZeroType must be half for activation=fp8"); - sm90_dispatch_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_gemm_to_cutlass( A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } else { throw std::runtime_error( - "[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] Arch unsupported for " + "[TensorRT LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] Arch unsupported for " "CUTLASS mixed type " "GEMM"); } @@ -465,6 +476,7 @@ void CutlassFpAIntBGemmRunner( @@ -487,6 +499,7 @@ void CutlassFpAIntBGemmRunner((ActivationType const*)A, (WeightType const*)B, (ScaleZeroType const*)weight_scales, nullptr, nullptr, @@ -519,6 +534,7 @@ void CutlassFpAIntBGemmRunner::getWorkspaceSize(int const m, int const n, int const k) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); // For Hopper, we have to allocate large memory size in case for stream-K if (sm_ == 90) { // 
https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h index a81fffde9d..e01dbd279c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h @@ -26,7 +26,7 @@ namespace tensorrt_llm { namespace kernels { -namespace cutlass_kernels { +namespace cutlass_kernels_oss { namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; @@ -43,6 +43,7 @@ void sm90_dispatch_epilogue_schedules( ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); switch (gemm_config.epilogue_schedule) { case tkc::EpilogueScheduleType::AUTO: using EpilogueScheduleType = @@ -57,7 +58,7 @@ void sm90_dispatch_epilogue_schedules( break; default: throw std::runtime_error( - "[TensorRT-LLM Error][fpA_intB][sm90_dispatch_epilogue_schedules] epilogue schedule " + "[TensorRT LLM Error][fpA_intB][sm90_dispatch_epilogue_schedules] epilogue schedule " "config is invalid for " "mixed " "type GEMM."); @@ -105,6 +106,8 @@ void sm90_dispatch_mainloop_schedules( ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + constexpr bool tile_shapes_supported = are_tile_shapes_supported(); if constexpr (tile_shapes_supported) { @@ -122,7 +125,7 @@ void sm90_dispatch_mainloop_schedules( break; default: throw std::runtime_error( - "[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] mainloop schedule " + "[TensorRT LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] mainloop schedule " "config is invalid " "for " "mixed type GEMM."); @@ -130,7 +133,7 @@ void sm90_dispatch_mainloop_schedules( } } else { throw std::runtime_error( - "[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] Unsupported CTA and " + "[TensorRT LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] Unsupported CTA and " "Cluster shapes for " "mixed type GEMM."); } @@ -146,6 +149,7 @@ void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); switch (gemm_config.cluster_shape) { case tkc::ClusterShape::ClusterShape_1x1x1: sm90_dispatch_mainloop_schedules::type; if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v< - cutlass::arch::Sm90, CTAShape, ClusterShape, ActivationType>) { + cutlass::arch::Sm90, CTAShape, ClusterShape, false, ActivationType>) { using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter::type; @@ -192,7 +195,7 @@ void sm90_generic_mixed_gemm_kernelLauncher( int 
cta_shape_k = cute::size<2>(TileShape{}); if (group_size % cta_shape_k != 0) { std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner]" + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner]" + err_msg); } if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { @@ -244,7 +247,7 @@ void sm90_generic_mixed_gemm_kernelLauncher( Gemm gemm; if (gemm.get_workspace_size(args) > workspace_bytes) { - TLLM_LOG_ERROR("[TensorRT-LLm Error][fpA_intB Runner] given workspace size insufficient."); + TLLM_LOG_ERROR("[TensorRT LLM Error][fpA_intB Runner] given workspace size insufficient."); } auto can_implement = gemm.can_implement(args); @@ -252,25 +255,25 @@ void sm90_generic_mixed_gemm_kernelLauncher( std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); std::cout << err_msg << std::endl; - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); } auto init_status = gemm.initialize(args, workspace, stream); if (init_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); } auto run_status = gemm.run(stream); if (run_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg); } } else { std::stringstream ss; - ss << "[TensorRT-LLm Error][fpA_intB Runner] Config (" << (int64_t)cute::size<0>(CTAShape{}) + ss << "[TensorRT LLM Error][fpA_intB Runner] Config (" << (int64_t)cute::size<0>(CTAShape{}) << "," << (int64_t)cute::size<1>(CTAShape{}) << "," << (int64_t)cute::size<2>(CTAShape{}) << ") (" << (int64_t)cute::size<0>(ClusterShape{}) << "," << (int64_t)cute::size<1>(ClusterShape{}) << "," << (int64_t)cute::size<2>(ClusterShape{}) @@ -281,12 +284,12 @@ void sm90_generic_mixed_gemm_kernelLauncher( #else // COMPILE_HOPPER_TMA_GEMMS throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB Runner] Please recompile with support for hopper by passing " + "[TensorRT LLM Error][fpA_intB Runner] Please recompile with support for hopper by passing " "90-real as an arch " "to build_wheel.py."); #endif // COMPILE_HOPPER_TMA_GEMMS } -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h index da4be7c179..9d493b8ef0 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
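Note on the hunk above: the mixed-type GEMM launcher follows the standard CUTLASS device-level sequence of sizing the workspace, checking can_implement, then initialize and run, turning any non-success status into an error. Below is a minimal sketch of that pattern, assuming Gemm is whatever CUTLASS device GEMM type the launcher instantiates; the helper name run_checked_gemm is ours, and unlike the sketch the real launcher only logs (rather than throws) when the workspace is too small.

#include <cuda_runtime.h>

#include <stdexcept>
#include <string>

#include "cutlass/cutlass.h"

// Runs a fully configured CUTLASS device GEMM with the same status checks the
// launcher above performs, collapsing every failure into one error path.
template <typename Gemm>
void run_checked_gemm(Gemm& gemm, typename Gemm::Arguments const& args, void* workspace,
                      size_t workspace_bytes, cudaStream_t stream) {
  if (gemm.get_workspace_size(args) > workspace_bytes) {
    throw std::runtime_error("[TensorRT LLM Error] given workspace size insufficient.");
  }
  cutlass::Status status = gemm.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(std::string("kernel cannot implement these params: ") +
                             cutlass::cutlassGetStatusString(status));
  }
  status = gemm.initialize(args, workspace, stream);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(std::string("failed to initialize GEMM: ") +
                             cutlass::cutlassGetStatusString(status));
  }
  status = gemm.run(stream);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(std::string("failed to run GEMM: ") +
                             cutlass::cutlassGetStatusString(status));
  }
}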
- * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { template void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl index 9a9ecafcd6..52355f34c7 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -1,28 +1,30 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
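Note on the SM80 fused MoE launcher declared above: its definition is the fused_moe_gemm_launcher_sm80.inl hunk beginning here, and later in that hunk the launcher opts the kernel into a large dynamic shared memory carve-out and checks occupancy before launching. A standalone CUDA sketch of that guard follows, with a placeholder kernel standing in for fused_moe_oss::run_global.

#include <cuda_runtime.h>

#include <cstdio>

// Placeholder for the real fused MoE kernel; only its launch attributes matter here.
__global__ void placeholder_kernel(int* out) {
  if (threadIdx.x == 0) *out = 0;
}

// Returns true if the kernel can run with `smem_bytes` of dynamic shared memory
// on the current device, opting in to the larger carve-out when needed.
bool can_launch_with_smem(size_t smem_bytes, int threads_per_block) {
  int device = 0;
  cudaGetDevice(&device);

  int max_smem_per_block = 0;
  cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  cudaFuncAttributes attr{};
  cudaFuncGetAttributes(&attr, placeholder_kernel);
  if (smem_bytes + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block)) {
    return false;  // Even the opt-in limit is too small for this configuration.
  }

  if (smem_bytes >= (48 << 10)) {
    // Dynamic shared memory above 48 KiB requires an explicit opt-in.
    if (cudaFuncSetAttribute(placeholder_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                             static_cast<int>(smem_bytes)) != cudaSuccess) {
      return false;
    }
  }

  int max_active_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, placeholder_kernel,
                                                threads_per_block, smem_bytes);
  std::printf("max active blocks per SM: %d\n", max_active_blocks);
  return max_active_blocks > 0;
}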
*/ -#include -#include - -#include - #include "cute/tensor.hpp" #include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh" +#include "tensorrt_llm/common/cudaUtils.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { template void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, @@ -32,10 +34,10 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy) { - constexpr auto activation_type = fused_moe::EpilogueRouting(true); + constexpr auto activation_type = fused_moe_oss::EpilogueRouting(true); using GemmType = - fused_moe::Fused_Moe_Kernel_sm80; + fused_moe_oss::Fused_Moe_Kernel_sm80; // make sure GPU has enough resources.. if (kernel_occupancy != nullptr) { @@ -49,7 +51,7 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe tensorrt_llm::common::check_cuda_error(cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); tensorrt_llm::common::check_cuda_error( - cudaFuncGetAttributes(&attr, fused_moe::run_global)); + cudaFuncGetAttributes(&attr, fused_moe_oss::run_global)); if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { // This should mean that // cudaFuncSetAttribute(cutlass::Kernel, @@ -62,11 +64,12 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe int max_active_blocks = -1; tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); + &max_active_blocks, fused_moe_oss::run_global, GemmType::kThreadCount, + smem_size)); *kernel_occupancy = max_active_blocks; return; } - int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); + int occupancy = std::min(2, fused_moe_oss::fused_gemm_maximum_active_blocks()); int const threadblock_count = multi_processor_count * occupancy; TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); @@ -80,7 +83,7 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe auto params = GemmType::to_underlying_arguments(args); if (GemmType::kSmemSize >= (48 << 10)) { cudaError_t result = - cudaFuncSetAttribute(fused_moe::run_global, + cudaFuncSetAttribute(fused_moe_oss::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + @@ -88,9 +91,9 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe } dim3 grid(params.threadblock_count, 1, 1); dim3 block(GemmType::kThreadCount); - fused_moe::run_global<<>>(params); + fused_moe_oss::run_global<<>>(params); auto result = cudaGetLastError(); TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int)(result)); } -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git 
a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h index badc07b574..ae2ad222b3 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h @@ -18,16 +18,19 @@ #include -#include "moe_gemm_kernels.h" - -namespace tensorrt_llm::kernels::cutlass_kernels { +#include "../../include/moe_gemm_kernels.h" +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; // Keep in sync with the signature generated by generate_kernels.py -template + typename ClusterShape, bool IsMXFPX, bool DYNAMIC_CGA, bool BIAS, bool SwapAB> void tma_warp_specialized_generic_moe_gemm_kernelLauncher( TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, - cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); + cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size, + cute::Shape dynamic_cluster_shape, + cute::Shape fallback_cluster_shape); -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index db5788bfdd..e61aad03cb 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -697,7 +697,7 @@ using namespace cutlass::epilogue; TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ std::string(cutlass::cutlassGetStatusString(init_status))); \ - auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ + auto run_status = gemm.run(stream, nullptr, tensorrt_llm::common::getEnvEnablePDL()); \ TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ "Failed to run cutlass TMA WS grouped gemm. Error: " + \ std::string(cutlass::cutlassGetStatusString(run_status))); \ diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h index 16ebddca32..91d12ef0e7 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h @@ -1,32 +1,39 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. 
Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include +#include "../../include/moe_gemm_kernels.h" #include "cutlass_extensions/gemm_configs.h" #include "cutlass_extensions/weight_only_quant_op.h" -#include "moe_gemm_kernels.h" namespace tensorrt_llm { namespace kernels { -namespace cutlass_kernels { - +namespace cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput; +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; template void sm90_generic_mixed_moe_gemm_kernelLauncher( - GroupedGemmInput inputs, + tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput + inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size); -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index e28cb7b129..8f4d2f7630 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -1,13 +1,17 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
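Note on an earlier hunk above: the TMA warp-specialized launch now reads gemm.run(stream, nullptr, tensorrt_llm::common::getEnvEnablePDL()), so the programmatic dependent launch (PDL) flag comes from an environment query instead of the input struct. A minimal sketch of such a cached, env-driven toggle follows; the variable name TRTLLM_ENABLE_PDL is an assumption for illustration only, not necessarily what the real helper reads.

#include <cstdlib>
#include <cstring>

// Reads the toggle once and caches it, so repeated kernel launches do not
// re-query the environment. The variable name below is illustrative only.
bool getEnvEnablePdlSketch() {
  static bool const enabled = [] {
    char const* value = std::getenv("TRTLLM_ENABLE_PDL");
    return value != nullptr && std::strcmp(value, "1") == 0;
  }();
  return enabled;
}

When the flag is true, the grouped GEMM is launched so that it may begin before the preceding kernel fully retires, which is the point of PDL.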
*/ #ifdef __GNUC__ // Check if the compiler is GCC or Clang @@ -44,28 +48,30 @@ #pragma GCC diagnostic pop #endif // __GNUC__ +#include "moe_gemm_tma_ws_mixed_input_launcher.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h" +namespace tensorrt_llm { +namespace kernels { +namespace cutlass_kernels_oss { +using namespace tensorrt_llm::kernels::cutlass_kernels; namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; using namespace cute; -namespace tensorrt_llm { -namespace kernels { -namespace cutlass_kernels { - template void sm90_generic_mixed_moe_gemm_kernelLauncher( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -181,40 +187,36 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher( hw_info.device_id = 0; hw_info.sm_count = sm_count_; - if (workspace_size != nullptr) { - const Args args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr}, - {reinterpret_cast(hopper_inputs.ptr_b), hopper_inputs.stride_b, - reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, - reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, - {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), - hopper_inputs.stride_c, reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, - hw_info}; - *workspace_size = gemm.get_workspace_size(args); - return; - } - - assert(group_size == int(inputs.groupwise_quant_group_size)); arguments = Args{ cutlass::gemm::GemmUniversalMode::kGrouped, {inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr}, - {reinterpret_cast(hopper_inputs.ptr_b), hopper_inputs.stride_b, - reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, + {reinterpret_cast(hopper_inputs.ptr_weight), + reinterpret_cast(hopper_inputs.stride_weight), + reinterpret_cast(hopper_inputs.ptr_act), + reinterpret_cast(hopper_inputs.stride_act), reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, - {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, - reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), - hopper_inputs.default_epilogue.stride_d}, + reinterpret_cast(hopper_inputs.int4_groupwise_params.stride_s_a), group_size}, + {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), + reinterpret_cast(hopper_inputs.stride_c), + reinterpret_cast(hopper_inputs.ptr_d), + reinterpret_cast(hopper_inputs.stride_d)}, hw_info}; + assert(group_size == int(inputs.groupwise_quant_group_size)); + if (workspace_size != nullptr) { + *workspace_size = gemm.get_workspace_size(arguments); + return; + } + if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) { TLLM_LOG_ERROR("[Mixed dtype WS grouped 
GEMM] given workspace size insufficient, %d < %d.", gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size); } + // This is not initialized during workspace size calculation so check after + TLLM_CHECK_WITH_INFO(hopper_inputs.swap_ab, + "swap_ab must be true for mixed dtype WS grouped GEMM"); + auto can_implement = gemm.can_implement(arguments); if (can_implement != cutlass::Status::kSuccess) { std::string err_msg = "mixed dtype WS grouped cutlass kernel will fail for params. Error: " + @@ -239,6 +241,6 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher( return; } -} // namespace cutlass_kernels +} // namespace cutlass_kernels_oss } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu index 1a350efc15..1072cdd1fa 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu index 3d020f2618..3e3db70369 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu @@ -14,10 +14,10 @@ * limitations under the License. 
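Note on the mixed-dtype grouped GEMM arguments above: they pass the weight tensor as the first operand and the activations as the second, and the new check requires hopper_inputs.swap_ab to be set. The identity being exploited is that C = A·B can equally be produced as Cᵀ = Bᵀ·Aᵀ with the operands exchanged. The CPU sketch below only demonstrates that identity; it says nothing about the kernel's actual data layouts.

#include <cassert>
#include <cmath>
#include <vector>

// C[m x n] = A[m x k] * B[k x n], all row-major.
std::vector<float> matmul(std::vector<float> const& a, std::vector<float> const& b, int m, int k,
                          int n) {
  std::vector<float> c(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p)
      for (int j = 0; j < n; ++j) c[i * n + j] += a[i * k + p] * b[p * n + j];
  return c;
}

std::vector<float> transpose(std::vector<float> const& x, int rows, int cols) {
  std::vector<float> t(rows * cols);
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j) t[j * rows + i] = x[i * cols + j];
  return t;
}

int main() {
  int const m = 2, k = 3, n = 4;
  std::vector<float> a = {1, 2, 3, 4, 5, 6};                    // m x k
  std::vector<float> b = {1, 0, 2, 1, 3, 1, 0, 2, 0, 1, 1, 3};  // k x n

  // Direct product, and the swapped product B^T * A^T, which equals C^T.
  auto c = matmul(a, b, m, k, n);
  auto ct = matmul(transpose(b, k, n), transpose(a, m, k), n, k, m);

  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) assert(std::fabs(c[i * n + j] - ct[j * m + i]) < 1e-5f);
  return 0;
}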
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { -#if defined(ENABLE_BF16) && defined(ENABLE_FP4) +#ifdef ENABLE_BF16 template class MoeGemmRunner<__nv_bfloat16, __nv_fp4_e2m1, __nv_bfloat16>; #endif } // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu index 8fd27b4c3f..da1adbac53 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu index 98ec5e7a64..b10a7f6713 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
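Note on the series of small .cu hunks here: each file holds a single explicit instantiation of MoeGemmRunner for one activation/weight pair, wrapped in preprocessor guards (the bf16_fp4 file above, for instance, now keys only on ENABLE_BF16). Splitting instantiations across translation units lets the heavy template header compile in parallel and lets guards drop unsupported pairs. A reduced sketch of the pattern with a hypothetical template, not the real MoeGemmRunner interface:

// runner_sketch.h -- heavy template code, included by every instantiation unit.
template <typename Activation, typename Weight>
class RunnerSketch {
 public:
  // Stand-in for the many expensive-to-compile members of the real runner.
  int run() const { return static_cast<int>(sizeof(Activation) + sizeof(Weight)); }
};

// runner_sketch_fp32_u8.cu -- one translation unit per type pair; guards can
// drop combinations the build does not enable.
// #ifdef ENABLE_SOME_TYPE
template class RunnerSketch<float, unsigned char>;
// #endif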
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu index 94ed59c0a6..cbf13d5f6f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_BF16 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu index a3af6d6c8a..aba083585f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. 
Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu index c4533161bd..91cd9413b8 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu @@ -14,10 +14,8 @@ * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { -#if defined(ENABLE_FP4) template class MoeGemmRunner; -#endif -} // namespace tensorrt_llm::kernels::cutlass_kernels +} diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu index 9a464b8311..d216bed89c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu index 92159da7ed..b9bdae53ac 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu index d26e9609fd..747f9b29e8 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { template class MoeGemmRunner; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu index a0137fd1c6..a8c11e0692 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP4 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu index c00b77dbc1..6bc740c5fa 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP4 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu index 7235cb5119..08b7ce1930 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP8 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu index 01a096b526..7dbf9f6265 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu @@ -1,16 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * - * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual - * property and proprietary rights in and to this material, related - * documentation and any modifications thereto. Any use, reproduction, - * disclosure or distribution of this material and related documentation - * without an express license agreement from NVIDIA CORPORATION or - * its affiliates is strictly prohibited. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h" +#include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP8 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index a0cb7777ca..6318992b73 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -67,7 +67,7 @@ #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { // ============================= Variable batched Gemm things =========================== template ::value || cutlass::platform::is_same::value || -#if defined(ENABLE_FP4) cutlass::platform::is_same::value || -#endif cutlass::platform::is_same::value); static_assert(arch::kMinComputeCapability < 90, @@ -106,9 +104,10 @@ struct genericMoeGemmKernelLauncher { // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if // necessary. - using ElementType = typename TllmToCutlassTypeAdapter::type; - using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; - using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; + using ElementType = typename cutlass_kernels::TllmToCutlassTypeAdapter::type; + using CutlassGemmOutputType = + typename cutlass_kernels::TllmToCutlassTypeAdapter::type; + using CutlassWeightType = typename cutlass_kernels::TllmToCutlassTypeAdapter::type; if (!inputs.use_fused_moe) { // We need separate config for each architecture since we will target different tensorcore // instructions. For float, we do not target TCs. @@ -213,9 +212,9 @@ struct genericMoeGemmKernelLauncher { // support fp16 or // bf16) { - sm80_generic_fused_moe_gemm_kernelLauncher( + tensorrt_llm::kernels::cutlass_kernels_oss::sm80_generic_fused_moe_gemm_kernelLauncher< + ElementType, CutlassWeightType, ThreadblockShape::kM, ThreadblockShape::kN, + ThreadblockShape::kK, Stages, EpilogueTag>( reinterpret_cast(inputs.A), reinterpret_cast(inputs.B), reinterpret_cast(inputs.biases), inputs.bias_is_broadcast, @@ -254,18 +253,19 @@ static void dispatch(GroupedGemmInput= 80) && (!isFp8 || std::is_same_v) && !isFp4) { // dispatch for quant op type - auto* launcher = kernels::cutlass_kernels::genericMoeGemmKernelLauncher< + auto* launcher = tensorrt_llm::kernels::cutlass_kernels_oss::genericMoeGemmKernelLauncher< T, WeightType, GemmOutputType, Arch, cutlass::WeightOnlyQuantOp::UNDEFINED, EpilogueTag, ThreadblockShape, WarpShape, Stages>::call; if (!std::is_same_v && inputs.groupwise_quant_group_size > 0) { - launcher = inputs.zeros ? 
kernels::cutlass_kernels::genericMoeGemmKernelLauncher< - T, WeightType, GemmOutputType, Arch, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, - EpilogueTag, ThreadblockShape, WarpShape, Stages>::call - : kernels::cutlass_kernels::genericMoeGemmKernelLauncher< - T, WeightType, GemmOutputType, Arch, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, EpilogueTag, - ThreadblockShape, WarpShape, Stages>::call; + launcher = inputs.zeros + ? tensorrt_llm::kernels::cutlass_kernels_oss::genericMoeGemmKernelLauncher< + T, WeightType, GemmOutputType, Arch, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, EpilogueTag, + ThreadblockShape, WarpShape, Stages>::call + : tensorrt_llm::kernels::cutlass_kernels_oss::genericMoeGemmKernelLauncher< + T, WeightType, GemmOutputType, Arch, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, EpilogueTag, + ThreadblockShape, WarpShape, Stages>::call; } launcher(inputs, sm_count_); } else { @@ -519,17 +519,23 @@ void dispatchMoeGemmToCutlass( } } +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss + +namespace tensorrt_llm::kernels::cutlass_kernels { + template std::vector -MoeGemmRunner::getConfigs() const { - return getConfigs(sm_); +MoeGemmRunner::getConfigs( + bool supports_finalize_fusion) const { + return getConfigs(sm_, supports_finalize_fusion); } template std::vector -MoeGemmRunner::getConfigs(int sm) { +MoeGemmRunner::getConfigs(int sm, + bool supports_finalize_fusion) { std::vector candidate_configs = - getTmaWarpSpecializedConfigs(sm); + getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion); std::vector ampere_configs = getAmpereConfigs(sm); std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); return candidate_configs; @@ -552,19 +558,21 @@ MoeGemmRunner::getAmpereConfigs(int sm auto config_type_param = static_cast( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || + if (!tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89) || use_wfp4a16) { return {}; } std::vector ampere_configs = - kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, + config_type_param); return ampere_configs; } template std::vector -MoeGemmRunner::getTmaWarpSpecializedConfigs(int sm) { +MoeGemmRunner::getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; @@ -577,28 +585,32 @@ MoeGemmRunner::getTmaWarpSpecializedCo static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; static constexpr auto fp4_only_flag = - (use_fp4 || use_wfp4afp4) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE; + (use_fp4 || use_wfp4afp8) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE; + static constexpr auto fp8fp4_mixed_flag = + use_wfp4afp8 ? 
CutlassGemmConfig::FP8FP4_MIXED : CutlassGemmConfig::NONE; auto config_type_param = static_cast( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_blackwell | enable_hopper | - fp8_only_flag | fp4_only_flag); + fp8_only_flag | fp4_only_flag | fp8fp4_mixed_flag); TLLM_CHECK_WITH_INFO(!(enable_blackwell && enable_hopper), "Blackwell and hopper flags are mutually exclusive"); + sm = use_wfp4afp8 && sm == 103 ? 100 : sm; if (sm >= 100 && sm < 120 && - !kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) { + !tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) { TLLM_LOG_TRACE( "Blackwell is not supported for this configuration, not selecting any TMA WS " "implementations"); return {}; } if ((sm == 120 || sm == 121) && - !kernels::cutlass_kernels::isValidSM120MOESpecialisation()) { + !tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()) { TLLM_LOG_TRACE( "Blackwell SM120 is not supported for this configuration, not selecting any TMA WS " "implementations"); return {}; } - if (enable_hopper && !kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { + if (enable_hopper && + !tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { TLLM_LOG_TRACE( "Hopper is not supported for this configuration, not selecting any TMA WS implementations"); return {}; @@ -606,6 +618,51 @@ MoeGemmRunner::getTmaWarpSpecializedCo std::vector tma_ws_configs = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + + if (sm == 103 && use_fp4) { + // Explicitly select SM100 as well + auto sm100_configs = tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs( + 100, max_split_k, config_type_param); + std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs)); + } + + if (supports_finalize_fusion) { + // Duplicate the configs and set the epilogue fusion type to FINALIZE + auto finalize_configs = tma_ws_configs; + std::transform(finalize_configs.begin(), finalize_configs.end(), + std::back_inserter(tma_ws_configs), [](auto& config) { + config.epilogue_fusion_type = + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + return config; + }); + + // Finalize fusion is only supported for TMA epilogue schedule + tma_ws_configs.erase( + std::remove_if( + tma_ws_configs.begin(), tma_ws_configs.end(), + [](auto& config) { + return config.epilogue_fusion_type == + cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE && + config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM; + }), + tma_ws_configs.end()); + } + + auto swap_ab_configs = tma_ws_configs; + std::transform(swap_ab_configs.begin(), swap_ab_configs.end(), std::back_inserter(tma_ws_configs), + [](auto& config) { + TLLM_CHECK_WITH_INFO(!config.swap_ab, "Swap AB is already set"); + config.swap_ab = true; + return config; + }); + + if (use_w4_groupwise) { + // w4 groupwise implementation requires swap_ab to be true + tma_ws_configs.erase(std::remove_if(tma_ws_configs.begin(), tma_ws_configs.end(), + [](auto& config) { return !config.swap_ab; }), + tma_ws_configs.end()); + } + return tma_ws_configs; } @@ -617,12 +674,15 @@ bool MoeGemmRunner::isTmaWarpSpecializ } template -bool MoeGemmRunner::supportsTmaWarpSpecialized() const { - return (sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || - (sm_ >= 100 && sm_ < 120 && - kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) || - ((sm_ == 120 || sm_ == 121) && - 
kernels::cutlass_kernels::isValidSM120MOESpecialisation()); +bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm) { + return (sm == 90 && + tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || + (sm >= 100 && sm < 120 && + tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< + T, WeightType>()) || + ((sm == 120 || sm == 121) && + tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()); } template @@ -677,18 +737,14 @@ void MoeGemmRunner::dispatchToArch( "Hopper configuration provided for non-Hopper architecture"); if (sm_ >= 75 && sm_ < 80) { -#ifdef ENABLE_FP4 if constexpr (!std::is_same_v) { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } else { TLLM_THROW("FP4 data type is not supported on SM < 90"); } -#else - TLLM_THROW("FP4 data type is not supported on SM < 90"); -#endif } else if (sm_ >= 80 && sm_ < 90) { -#ifdef ENABLE_FP4 if constexpr (!std::is_same_v) { if constexpr (use_fp8 || use_w4afp8) { #if defined(ENABLE_FP8) @@ -696,44 +752,39 @@ void MoeGemmRunner::dispatchToArch( !std::is_same_v, "FP8 GEMM Output not supported"); #endif + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } else { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } } else { TLLM_THROW("FP4 data type is not supported on SM < 90"); } -#else - TLLM_THROW("FP4 data type is not supported on SM < 90"); -#endif } else if (sm_ >= 90) { - // For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations. - if constexpr (use_fp8) { + // For SM120+ pure FP8 MoE (not FP8 x FP4), redirect to SM89 (Ada) FP8 kernel implementations. + if constexpr (use_fp8 && !use_wfp4afp8) { if (sm_ >= 120) { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); return; } } - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation< - T, WeightType, EpilogueTag>() && + if constexpr (tensorrt_llm::kernels::cutlass_kernels:: + isValidTmaWarpSpecializedMOESpecialisation() && !use_w4_groupwise) { // We allow both tma warp specialized and SM80 configurations to coexist because for some // cases with small numbers of tokens SM80 is faster. 
We check here to see which is selected if (inputs.gemm_config.sm_version >= 90) { - bool is_same_sm = inputs.gemm_config.sm_version == sm_; - // gemm_config.sm_version indicates the kernel pipeline, which is always 100 for 100, 103, - // 110 below logging helps confirming the cutlass pipeline matches the device major version - bool is_sm110 = inputs.gemm_config.sm_version == 100 && sm_ == 110; - bool is_sm103 = inputs.gemm_config.sm_version == 100 && sm_ == 103; - // SM120 and SM121 are architecturally identical - bool is_sm120 = (inputs.gemm_config.sm_version == 120) && (sm_ == 120 || sm_ == 121); - TLLM_CHECK_WITH_INFO(is_same_sm || is_sm110 || is_sm103 || is_sm120, + // Check the major version of the SM matches + TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10, "Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_); TLLM_CHECK_WITH_INFO(inputs.biases != nullptr || hopper_inputs.ptr_c == nullptr, @@ -746,11 +797,11 @@ void MoeGemmRunner::dispatchToArch( auto select_function = [&]() { switch (hopper_inputs.fusion) { case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE: - return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< + return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< T, WeightType, OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE>; case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE: - return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< + return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized< T, WeightType, OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>; case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION: @@ -775,16 +826,16 @@ void MoeGemmRunner::dispatchToArch( "w4afp8 is only supported for TMA warp specialization"); // EpilogueTag is ignored if (inputs.k % 512 == 0) { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 4>( inputs, hopper_inputs, multi_processor_count_, nullptr); } else if (inputs.k % 256 == 0) { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 2>( inputs, hopper_inputs, multi_processor_count_, nullptr); } else if (inputs.k % 128 == 0) { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 1>( inputs, hopper_inputs, multi_processor_count_, nullptr); } else { TLLM_THROW("Invalid GEMM K size %d", (int)inputs.k); @@ -796,16 +847,16 @@ void MoeGemmRunner::dispatchToArch( TLLM_CHECK_WITH_INFO(inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization"); // EpilogueTag is ignored - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( + cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass< + T, WeightType, ScaleBiasType, cutlass_extensions::EpilogueOpDefault, 1>( inputs, hopper_inputs, multi_processor_count_, nullptr); return; } #endif // Do Ampere case instead - if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) { + if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidAmpereMOESpecialisation< + T, WeightType, EpilogueTag>()) { TLLM_CHECK_WITH_INFO(!use_fp8, "No 
fallback FP8 implementation available"); TLLM_CHECK_WITH_INFO(use_w4afp8 || !hopper_inputs.isValid(), "Non-specialized Hopper implementation is being rerouted to fallback " @@ -818,10 +869,12 @@ void MoeGemmRunner::dispatchToArch( "Using SM %d configuration for SM80 fallback implementation", inputs.gemm_config.sm_version); if constexpr (use_fp8) { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } else { - dispatchMoeGemmToCutlass( + cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } } else { @@ -848,18 +901,21 @@ template ::calcMaxWorkspaceSize( int num_experts) const { if constexpr (use_w4_groupwise) { - return calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( + return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( num_experts, multi_processor_count_); } if (!supportsTmaWarpSpecialized()) { return 0; } - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation< + if constexpr (tensorrt_llm::kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation< T, WeightType>() && !use_w4afp8 && !use_wfp4a16) { - auto configs = getTmaWarpSpecializedConfigs(sm_); + // Finalize fusion may not actually be supported by the kernel, + // if they are not we will catch the error and skip them + auto configs = getTmaWarpSpecializedConfigs(sm_, true); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; - if constexpr (use_wfp4afp4) { + if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } else if (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; @@ -867,17 +923,19 @@ size_t MoeGemmRunner::calcMaxWorkspace size_t max_size = 0; bool has_config = false; for (auto conf : configs) { -#define CALC_SIZE_FUSION(FUSION) \ - do { \ - try { \ - size_t size = calcMaxWorkspaceSizeTmaWarpSpecialized( \ - num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \ - max_size = std::max(max_size, size); \ - has_config = true; \ - } catch (tensorrt_llm::common::TllmException const& e) { \ - TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size %s", \ - e.what()); \ - } \ +#define CALC_SIZE_FUSION(FUSION) \ + do { \ + try { \ + size_t size = \ + cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecialized( \ + num_experts, conf, multi_processor_count_, fpX_block_scaling_type); \ + max_size = std::max(max_size, size); \ + has_config = true; \ + } catch (tensorrt_llm::common::TllmException const& e) { \ + TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size %s", \ + e.what()); \ + } \ } while (0) CALC_SIZE_FUSION(TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index c764cb6c90..4fd4daa8d1 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -49,6 +49,7 @@ #include #include +#include #include #include "../include/moe_gemm_kernels.h" @@ -59,15 +60,59 @@ #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" #include 
"tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; +template +auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilogue_schedule, + bool dynamic_cga, bool swap_ab) { + auto select_swap_ab = [dynamic_cga, epilogue_schedule](auto swap_ab_t) { + auto select_dynamic_cga = [epilogue_schedule](auto dynamic_cga_t) { + constexpr bool is_block_scaled = + std::is_same_v || std::is_same_v; + if constexpr ((!is_block_scaled || Arch::kMinComputeCapability == 103) && + FUSION != EpilogueFusion::FINALIZE) { + auto func_map = std::array{ + &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, + EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>, + &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, + EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value> + + }; + bool const tma_epilogue = + epilogue_schedule == cutlass_extensions::EpilogueScheduleType::TMA; + return func_map[tma_epilogue]; + } else { + static_assert(FUSION == EpilogueFusion::FINALIZE || Arch::kMinComputeCapability != 103, + "SM103 should support both epilogue schedules"); + TLLM_CHECK_WITH_INFO( + epilogue_schedule == cutlass_extensions::EpilogueScheduleType::TMA, + "No Smem epilogue schedule is not supported for block scaled types or finalize fusion"); + return &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, + EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>; + } + }; + return dynamic_cga ? select_dynamic_cga(tensorrt_llm::common::ConstBool{}) + : select_dynamic_cga(tensorrt_llm::common::ConstBool{}); + }; + return swap_ab ? select_swap_ab(tensorrt_llm::common::ConstBool{}) + : select_swap_ab(tensorrt_llm::common::ConstBool{}); +} + template -void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmInput hopper_input, - int num_experts, int multi_processor_count, - cudaStream_t stream, int* occupancy, - size_t* workspace_size) { +void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( + TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, + cudaStream_t stream, int* occupancy, size_t* workspace_size) { static_assert( (Arch::kMinComputeCapability == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || @@ -79,15 +124,6 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn TLLM_CHECK_WITH_INFO(workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information"); - // auto func = hopper_input.ptr_c ? 
- // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper - // : - // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper; - // TODO Re-enable bias when CUTLASS supports it - if constexpr (Arch::kMinComputeCapability < 90) { TLLM_THROW("Invalid architecture instantiated"); } @@ -98,6 +134,13 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn "build_wheel.py."); } #endif +#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS + else if constexpr (Arch::kMinComputeCapability == 103) { + TLLM_THROW( + "Please recompile with support for blackwell by passing 103-real as an arch to " + "build_wheel.py."); + } +#endif #ifndef COMPILE_BLACKWELL_TMA_GROUPED_GEMMS else if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { TLLM_THROW( @@ -113,39 +156,74 @@ void dispatchMoeGemmSelectBiasTmaWarpSpecialized(TmaWarpSpecializedGroupedGemmIn } #endif else { -#ifdef ENABLE_FP4 - auto getFunc = [&]() { - if constexpr (std::is_same_v && std::is_same_v) { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is the only supported scaling type for WFP4AFP8"); - return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, true, - false>; - } else { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is not supported for the selected weight combination"); - return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, false, - false>; - } - }; - getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); -#else - TLLM_THROW("FP4 data type is not supported on this architecture and CUDA version"); -#endif + constexpr static bool is_wfp4afp8 = + std::is_same_v && std::is_same_v; + if constexpr (is_wfp4afp8) { + TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, + "MXFPX is the only supported scaling type for WFP4AFP8"); + } else { + TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, + "MXFPX is not supported for the selected weight combination"); + } + + if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { + bool const dynamic_cga = + gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; + bool const swap_ab = hopper_input.swap_ab; + auto cluster_shape = + cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); + auto cluster_shape_cute = cute::Shape{ + std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; + auto cluster_shape_fallback = + cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); + auto cluster_shape_cute_fallback = cute::Shape{ + std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; + + auto selected_func = + getDispatchFunctionForSM100( + gemm_config.epilogue_schedule, dynamic_cga, swap_ab); + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); + } else if constexpr (Arch::kMinComputeCapability >= 120 || 
Arch::kMinComputeCapability == 90) { + using EpilogueSchedule = void; // These are hardcoded in the launcher + constexpr bool dynamic_cga = false; + auto selected_func = + hopper_input.swap_ab + ? kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, true> + : kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, false>; + + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, {}, {}); + } } } -template +template constexpr bool are_tile_shapes_supported_sm100() { + // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes. + if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 || + cute::size<2>(ClusterShape{}) != 1) { + return false; + } + using namespace cute; - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // This is the epilogue shape. The MMA shape will be twice this for 2SM constexpr auto TileM = size<0>(CtaShape{}); constexpr auto TileN = size<1>(CtaShape{}); + if constexpr (Arch::kMinComputeCapability == 103) { + return std::is_same_v && std::is_same_v && + TileM == 128 && (TileN == 128 || TileN == 256); + } + if constexpr (TileM != 64 && TileM != 128) { return false; } @@ -181,14 +259,13 @@ constexpr bool are_tile_shapes_supported_sm100() { return true; } -template +template constexpr bool are_tile_shapes_supported_sm120() { using namespace cute; if constexpr (cute::size<0>(ClusterShape{}) != 1 || cute::size<1>(ClusterShape{}) != 1 || cute::size<2>(ClusterShape{}) != 1) { return false; } - using CtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // This is the epilogue shape. The MMA shape will be twice this for 2SM constexpr auto TileM = size<0>(CtaShape{}); constexpr auto TileN = size<1>(CtaShape{}); @@ -216,7 +293,7 @@ template constexpr bool are_tile_shapes_supported() { if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { - return are_tile_shapes_supported_sm100(); + return are_tile_shapes_supported_sm100(); } else if constexpr (Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121) { return are_tile_shapes_supported_sm120(); } @@ -247,14 +324,16 @@ void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, size_t* workspace_size) { using namespace cute; + // This uses the fallback cluster shape for sm100 if a dynamic cluster shape is requested. 
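// Editor's sketch (not part of the patch): getDispatchFunctionForSM100 above turns the runtime
// flags (dynamic_cga, swap_ab) into compile-time template arguments by dispatching through
// generic lambdas over a ConstBool-like tag. A minimal, self-contained illustration of that
// pattern, using a stand-in ConstBool and a hypothetical launch_kernel:
#include <cstdio>
#include <type_traits>

template <bool B>
using ConstBool = std::integral_constant<bool, B>;  // stand-in for tensorrt_llm::common::ConstBool

template <bool DynamicCga, bool SwapAb>
void launch_kernel() {  // hypothetical launcher specialisation
  std::printf("dynamic_cga=%d swap_ab=%d\n", int(DynamicCga), int(SwapAb));
}

void dispatch(bool dynamic_cga, bool swap_ab) {
  auto select_swap = [&](auto swap_t) {
    auto select_cga = [&](auto cga_t) {
      // decltype(...)::value is a constexpr bool, so it can parameterise the template.
      launch_kernel<decltype(cga_t)::value, decltype(swap_t)::value>();
    };
    dynamic_cga ? select_cga(ConstBool<true>{}) : select_cga(ConstBool<false>{});
  };
  swap_ab ? select_swap(ConstBool<true>{}) : select_swap(ConstBool<false>{});
}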
switch (gemm_config.cluster_shape) { #define SHAPE_CASE(M, N, K) \ case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: { \ using ClusterShape = Shape<_##M, _##N, _##K>; \ if constexpr (are_tile_shapes_supported()) { \ - dispatchMoeGemmSelectBiasTmaWarpSpecialized( \ - hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ + dispatchMoeGemmFinalDispatchTmaWarpSpecialized( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ + workspace_size); \ break; \ } else { \ TLLM_THROW( \ @@ -275,7 +354,8 @@ void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( #undef SHAPE_CASE default: - TLLM_THROW("Unsupported config %d for MoE gemm.", (int)gemm_config.cluster_shape); + TLLM_THROW("Unsupported cluster shape config %d for MoE gemm.", + (int)gemm_config.cluster_shape); } } // namespace tensorrt_llm @@ -301,15 +381,16 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( workspace_size); \ break; \ } -#define DEFAULT_CASE(SMVERSION) \ - case cutlass_extensions::CutlassTileConfigSM##SMVERSION::Undefined: \ - TLLM_THROW("GEMM config undefined."); \ - break; \ - case cutlass_extensions::CutlassTileConfigSM##SMVERSION::ChooseWithHeuristic: \ - TLLM_THROW("GEMM config should have already been set by heuristic."); \ - break; \ - default: \ - TLLM_THROW("Unsupported config %d for MoE gemm.", (int)gemm_config.tile_config_sm##SMVERSION); \ +#define DEFAULT_CASE(SMVERSION) \ + case cutlass_extensions::CutlassTileConfigSM##SMVERSION::Undefined: \ + TLLM_THROW("GEMM config undefined."); \ + break; \ + case cutlass_extensions::CutlassTileConfigSM##SMVERSION::ChooseWithHeuristic: \ + TLLM_THROW("GEMM config should have already been set by heuristic."); \ + break; \ + default: \ + TLLM_THROW("Unsupported tile shape config %d for MoE gemm.", \ + (int)gemm_config.tile_config_sm##SMVERSION); \ break; if (gemm_config.sm_version == 90) { @@ -327,29 +408,29 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( } else { TLLM_THROW("Unsupported SM90 configuration requested"); } - } else if (gemm_config.sm_version == 110) { + } +#if defined(ENABLE_FP4) && defined(COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS) + // Check this before SM100 because we fall back to SM100 if not NVFP4 + else if (gemm_config.sm_version == 103 && std::is_same_v && + std::is_same_v) { if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm100) { - SHAPE_CASE(100, 64, 64, 128) - SHAPE_CASE(100, 64, 128, 128) - SHAPE_CASE(100, 64, 256, 128) + SHAPE_CASE(103, 128, 128, 128) + SHAPE_CASE(103, 128, 256, 128) - SHAPE_CASE(100, 128, 16, 128) - SHAPE_CASE(100, 128, 32, 128) - SHAPE_CASE(100, 128, 64, 128) - SHAPE_CASE(100, 128, 128, 128) - SHAPE_CASE(100, 128, 256, 128) - - DEFAULT_CASE(100) + DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103 } } else { - TLLM_THROW("Unsupported SM110 configuration requested"); + TLLM_THROW("Unsupported SM103 configuration requested"); } - } else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 110) { + } +#endif + else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) { if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(100, 64, 32, 128) SHAPE_CASE(100, 64, 64, 128) SHAPE_CASE(100, 64, 128, 128) SHAPE_CASE(100, 64, 256, 128) @@ -360,13 +441,8 @@ void 
dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( SHAPE_CASE(100, 128, 128, 128) SHAPE_CASE(100, 128, 256, 128) - SHAPE_CASE(100, 256, 64, 128) - SHAPE_CASE(100, 256, 128, 128) - SHAPE_CASE(100, 256, 256, 128) - // SHAPE_CASE(100, 128, 128, 64) // SHAPE_CASE(100, 128, 256, 64) - // SHAPE_CASE(100, 256, 256, 64) DEFAULT_CASE(100) } } else { @@ -404,4 +480,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecialized( return count; } -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h index 3375a60716..8c5a5d45e7 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h @@ -57,8 +57,10 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -namespace tensorrt_llm::kernels::cutlass_kernels { +namespace tensorrt_llm::kernels::cutlass_kernels_oss { +using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput; +using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; @@ -69,6 +71,7 @@ template inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS switch (inputs.gemm_config.mainloop_schedule) { case tkc::MainloopScheduleType::COOPERATIVE: @@ -120,6 +123,7 @@ template inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); switch (inputs.gemm_config.cluster_shape) { case tkc::ClusterShape::ClusterShape_1x1x1: sm90_dispatch_mainloop_schedules inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size) { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually // perform the best for mixed type gemms. -#if defined(ENABLE_FP4) constexpr int Ntile = (std::is_same_v) ? 64 : 128; constexpr int Ktile = (std::is_same_v) ? 128 : 128 * PackedScalesNum / sizeof(T); TLLM_CHECK(sizeof(T) == (std::is_same_v) ? 2 : 1); -#else - constexpr int Ntile = 128; - constexpr int Ktile = 128 * PackedScalesNum / sizeof(T); - TLLM_CHECK(sizeof(T) == 2); -#endif using _Ntile = Int; using _Ktile = Int; + switch (inputs.gemm_config.tile_config_sm90) { case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: sm90_dispatch_moe_mixed_dtype_gemm_config size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_count_) { size_t count = 0; -#ifdef ENABLE_FP4 constexpr int Ktile = (std::is_same_v) ? 
256 : 512; -#else - constexpr int Ktile = 512; -#endif using _Ktile = Int; #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS @@ -267,4 +263,4 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_ return count; } -} // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels_oss diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu index f240680c6b..52cd03887b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu @@ -14,43 +14,45 @@ * limitations under the License. */ +#include "../include/moe_gemm_kernels.h" #include "cute/tensor.hpp" #include "cutlass/conv/convolution.h" #include "cutlass/cutlass.h" -#include "moe_gemm_kernels.h" // Order matters here, packed_stride.hpp is missing cute and convolution includes #include "cutlass/util/packed_stride.hpp" #include "tensorrt_llm/common/logger.h" namespace tensorrt_llm::kernels::cutlass_kernels { -std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( +std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( int num_experts, FpXBlockScalingType scaling_type) { size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; - size_t stride_a_size = sizeof(StrideA) * num_experts; - size_t stride_b_size = sizeof(StrideB) * num_experts; - size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + size_t stride_act_size = std::max(sizeof(StrideA), sizeof(StrideB)) * num_experts; + size_t stride_weight_size = std::max(sizeof(StrideA), sizeof(StrideB)) * num_experts; + size_t stride_c_size = std::max(sizeof(StrideC), sizeof(StrideC_T)) * num_experts; + size_t stride_d_size = std::max(sizeof(StrideD), sizeof(StrideD_T)) * num_experts; size_t ptr_buf_size = sizeof(void*) * num_experts; size_t scale_buf_size = sizeof(float*) * num_experts; - size_t sf_a_size = sizeof(ElementSF*) * num_experts; - size_t sf_b_size = sizeof(ElementSF*) * num_experts; - size_t stride_sf_a_size = scaling_type == FpXBlockScalingType::MXFPX - ? sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts - : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; - size_t stride_sf_b_size = scaling_type == FpXBlockScalingType::MXFPX - ? sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts - : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; + size_t sf_act_size = sizeof(ElementSF*) * num_experts; + size_t sf_weight_size = sizeof(ElementSF*) * num_experts; + size_t stride_sf_act_size = scaling_type == FpXBlockScalingType::MXFPX + ? sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts + : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; + size_t stride_sf_weight_size = scaling_type == FpXBlockScalingType::MXFPX + ? 
sizeof(MXFPXBlockScaledConfig::LayoutSF) * num_experts + : sizeof(NVFP4BlockScaledConfig::LayoutSF) * num_experts; size_t int4_groupwise_problem_shape_size = sizeof(INT4GroupwiseParams::ProblemShapeInt::UnderlyingProblemShape) * num_experts; size_t int4_groupwise_sf_a_size = sizeof(INT4GroupwiseParams::SFA*) * num_experts; size_t int4_groupwise_stride_sf_a_size = sizeof(INT4GroupwiseParams::StrideSFA) * num_experts; + size_t ptr_token_map_size = sizeof(int**) * num_experts; + return std::array{problem_shape_size, - stride_a_size, - stride_b_size, + stride_act_size, + stride_weight_size, stride_c_size, stride_d_size, ptr_buf_size, @@ -58,13 +60,16 @@ std::array TmaWarpSpecializedGroupedGemmInput::workspaceBuffers( ptr_buf_size, ptr_buf_size, scale_buf_size, - sf_a_size, - sf_b_size, - stride_sf_a_size, - stride_sf_b_size, + sf_act_size, + sf_weight_size, + stride_sf_act_size, + stride_sf_weight_size, int4_groupwise_problem_shape_size, int4_groupwise_sf_a_size, - int4_groupwise_stride_sf_a_size}; + int4_groupwise_stride_sf_a_size, + ptr_buf_size, + scale_buf_size, + ptr_token_map_size}; } size_t TmaWarpSpecializedGroupedGemmInput::workspaceSize(int num_experts, @@ -78,7 +83,7 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i size_t gemm_workspace_size, FpXBlockScalingType scaling_type) { auto buffers = workspaceBuffers(num_experts, scaling_type); - std::array pointers{}; + std::array pointers{}; TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); for (int i = 0; i < buffers.size(); i++) { @@ -89,23 +94,23 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i shape_info.num_groups = num_experts; shape_info.problem_shapes = reinterpret_cast(pointers[0]); shape_info.host_problem_shapes = nullptr; - stride_a = reinterpret_cast(pointers[1]); - stride_b = reinterpret_cast(pointers[2]); - stride_c = reinterpret_cast(pointers[3]); - default_epilogue.stride_d = reinterpret_cast(pointers[4]); + stride_act = reinterpret_cast(pointers[1]); + stride_weight = reinterpret_cast(pointers[2]); + stride_c = reinterpret_cast(pointers[3]); + stride_d = reinterpret_cast(pointers[4]); - ptr_a = reinterpret_cast(pointers[5]); - ptr_b = reinterpret_cast(pointers[6]); + ptr_act = reinterpret_cast(pointers[5]); + ptr_weight = reinterpret_cast(pointers[6]); ptr_c = reinterpret_cast(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + ptr_d = reinterpret_cast(pointers[8]); alpha_scale_ptr_array = reinterpret_cast(pointers[9]); - fpX_block_scaling_factors_A = reinterpret_cast(pointers[10]); - fpX_block_scaling_factors_B = reinterpret_cast(pointers[11]); + fpX_block_scaling_factors_act = reinterpret_cast(pointers[10]); + fpX_block_scaling_factors_weight = reinterpret_cast(pointers[11]); - fpX_block_scaling_factors_stride_A = pointers[12]; - fpX_block_scaling_factors_stride_B = pointers[13]; + fpX_block_scaling_factors_stride_act = pointers[12]; + fpX_block_scaling_factors_stride_weight = pointers[13]; int4_groupwise_params.shape.problem_shapes = reinterpret_cast(pointers[14]); @@ -114,27 +119,30 @@ void TmaWarpSpecializedGroupedGemmInput::configureWorkspace(int8_t* start_ptr, i int4_groupwise_params.stride_s_a = reinterpret_cast(pointers[16]); + fused_finalize_epilogue.ptr_bias = reinterpret_cast(pointers[17]); + fused_finalize_epilogue.ptr_router_scales = reinterpret_cast(pointers[18]); + fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast(pointers[19]); + 
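// Editor's sketch (not part of the patch): configureWorkspace carves the per-expert pointer and
// stride arrays out of a single backing allocation, in the same order that workspaceBuffers
// reports their sizes. A simplified version of that carving pattern, with illustrative buffer
// names/sizes and an assumed 256-byte alignment:
#include <array>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kAlign = 256;  // assumed alignment, not the real value

inline std::size_t alignUp(std::size_t x) { return (x + kAlign - 1) / kAlign * kAlign; }

// Returns base pointers for each sub-buffer; the total workspace size would be the sum of the
// aligned sizes so that size reporting and pointer assignment always agree.
inline std::array<int8_t*, 3> carveWorkspace(int8_t* base, int num_experts) {
  std::array<std::size_t, 3> sizes{
      sizeof(int64_t) * 4 * num_experts,  // e.g. problem shapes (illustrative)
      sizeof(int64_t) * 3 * num_experts,  // e.g. strides (illustrative)
      sizeof(void*) * num_experts};       // e.g. per-expert pointer array
  std::array<int8_t*, 3> ptrs{};
  std::size_t offset = 0;
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    ptrs[i] = base + offset;
    offset += alignUp(sizes[i]);
  }
  return ptrs;
}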
this->gemm_workspace = reinterpret_cast(gemm_workspace); this->gemm_workspace_size = gemm_workspace_size; } -void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams( - void* final_output, float const* router_scales, int64_t const* expert_first_token_offset, - int const* source_token_index, void const* bias, int hidden_size, int num_output_tokens) { +void TmaWarpSpecializedGroupedGemmInput::setFinalizeFusionParams(void* final_output, + int hidden_size, + int num_output_tokens, + bool use_reduction) { fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - - fused_finalize_epilogue.stride_final_output = cutlass::make_cute_packed_stride( - FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias = - transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.stride_final_output = + cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, + cute::make_shape(num_output_tokens, hidden_size, 1)); + fused_finalize_epilogue.stride_final_output_transposed = + cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput_T{}, + cute::make_shape(hidden_size, num_output_tokens, 1)); fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; + fused_finalize_epilogue.shape_override = hidden_size; + fused_finalize_epilogue.use_reduction = use_reduction; } std::string TmaWarpSpecializedGroupedGemmInput::toString() const { @@ -142,32 +150,29 @@ std::string TmaWarpSpecializedGroupedGemmInput::toString() const { ss << "Hopper Input Information: " << (isValid() ? 
"valid" : "null") << "\n"; if (isValid()) { using PrintType = void const*; - ss << "Ptr A: " << (PrintType)ptr_a << " with Stride: " << (PrintType)stride_a << ",\n" - << "Ptr B: " << (PrintType)ptr_b << " with Stride: " << (PrintType)stride_b << ",\n" + ss << "Ptr Act: " << (PrintType)ptr_act << " with Stride: " << (PrintType)stride_act << ",\n" + << "Ptr Weight: " << (PrintType)ptr_weight << " with Stride: " << (PrintType)stride_weight + << ",\n" << "Ptr C: " << (PrintType)ptr_c << " with Stride: " << (PrintType)stride_c << "\n"; ss << "Epilogue Fusion: " << (int)fusion << ",\n"; if (fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) { ss << "Final Output: " << (PrintType)fused_finalize_epilogue.ptr_final_output; ss << " with Stride: " << fused_finalize_epilogue.stride_final_output; ss << ",\nBias: " << (PrintType)fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " - << (PrintType)fused_finalize_epilogue.ptr_expert_first_token_offset; ss << ", Source Map: " << (PrintType)fused_finalize_epilogue.ptr_source_token_index; } else { - ss << "Ptr D: " << (PrintType)default_epilogue.ptr_d; - ss << " with Stride: " << (PrintType)default_epilogue.stride_d; + ss << "Ptr D: " << (PrintType)ptr_d; + ss << " with Stride: " << (PrintType)stride_d; } ss << '\n'; ss << "Alpha scale ptr: " << (PrintType)alpha_scale_ptr_array << "\n"; ss << "FpX Block Scaling Type: " << (int)fpX_block_scaling_type << "\n"; - ss << "Fp4 Block Scaling Factors A: " << (PrintType)fpX_block_scaling_factors_A - << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_A << "\n"; - ss << "Fp4 Block Scaling Factors B: " << (PrintType)fpX_block_scaling_factors_B - << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_B << "\n"; + ss << "Fp4 Block Scaling Factors Act: " << (PrintType)fpX_block_scaling_factors_act + << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_act << "\n"; + ss << "Fp4 Block Scaling Factors Weight: " << (PrintType)fpX_block_scaling_factors_weight + << ", with Stride: " << (PrintType)fpX_block_scaling_factors_stride_weight << "\n"; ss << "Gemm Workspace: " << (PrintType)gemm_workspace << ", with Size: " << gemm_workspace_size << "\n"; } diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index d2f742d3d0..33a9fc8530 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -16,9 +16,9 @@ #pragma once +#include "../include/moe_gemm_kernels.h" #include "cutlass/arch/mma_sm90.h" #include "cutlass_extensions/epilogue_helpers.h" -#include "moe_gemm_kernels.h" #ifdef ENABLE_FP4 #include @@ -33,10 +33,11 @@ template constexpr bool isValidSM120MOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice - return cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - cutlass::platform::is_same::value && - Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + return ((cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + 
(cutlass::platform::is_same::value && + cutlass::platform::is_same::value)) && + cutlass::platform::is_same::value; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -51,8 +52,7 @@ constexpr bool isValidBlackwellMOESpecialisation() { return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && cutlass::platform::is_same::value)) && - cutlass::platform::is_same::value && - Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; + cutlass::platform::is_same::value; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -68,14 +68,8 @@ constexpr bool isValidHopperMOESpecialisation() { return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && cutlass::platform::is_same::value) || -#ifdef ENABLE_FP4 (cutlass::platform::is_same<__nv_fp4_e2m1, WeightType>::value && - !cutlass::platform::is_same::value) -#else - false -#endif - ) - + !cutlass::platform::is_same::value)) #ifdef ENABLE_FP4 && !cutlass::platform::is_same::value #endif @@ -92,7 +86,8 @@ template constexpr bool isValidTmaWarpSpecializedMOESpecialisation() { // Check at least one of the implementations are valid - return isValidBlackwellMOESpecialisation() || + return isValidSM120MOESpecialisation() || + isValidBlackwellMOESpecialisation() || isValidHopperMOESpecialisation(); } From 307fe30a2daf466cc4048e5775ea2682222c3af6 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Mon, 13 Oct 2025 21:08:17 +0000 Subject: [PATCH 07/17] fix compilation errors in cutlass_fused_moe_kernels.cuh --- .../cutlass_fused_moe_instantiation.cu | 4 +- .../cutlass_fused_moe_kernels.cuh | 12 +- .../cutlass_kernels/cutlass_heuristic.h | 10 +- .../include/moe_gemm_kernels.h | 3 + .../cutlass_kernels/include/moe_kernels.h | 156 ++++++++++-------- .../launchers/moe_gemm_tma_ws_launcher.inl | 2 +- 6 files changed, 105 insertions(+), 82 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index 50dfcf78b9..5427dbdac4 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -17,6 +17,7 @@ #include "cutlass_fused_moe_kernels.cuh" #include "moe_kernels.h" +namespace tensorrt_llm::kernels::cutlass_kernels { template class CutlassMoeFCRunner; #ifdef ENABLE_BF16 @@ -53,5 +54,4 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, _ template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>; #endif #endif -} -; // namespace tensorrt_llm::kernels::cutlass_kernels +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 162e38bc65..67c1c62a74 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -984,7 +984,7 @@ __device__ auto quantizePackedFPXValue( cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED); + QuantizationSFLayout::SWIZZLED_128x4); // Do the conversion and set the output and scaling factor auto func = [&]() { @@ -1026,7 +1026,7 @@ __device__ void writeSF(int64_t num_tokens_before_expert, 
int64_t expert_id, cvt_quant_get_sf_out_offset( std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED); + QuantizationSFLayout::SWIZZLED_128x4); if (sf_out) { if (input_sf) { if (swizzled_input_sf) { @@ -1036,7 +1036,7 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, num_cols / VecSize, const_cast(input_sf), - QuantizationSFLayout::SWIZZLED); + QuantizationSFLayout::SWIZZLED_128x4); *sf_out = *sf_in; } else { auto const sf_in = @@ -3870,8 +3870,8 @@ CutlassMoeFCRunner:: TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, - UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, bool enable_pdl, - float const* router_scales, int const* permuted_row_to_unpermuted_row, + UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream) { // Always nullptr layout_info1.ptr_c = nullptr; @@ -4722,7 +4722,7 @@ size_t GemmProfilerBackend::getWorkspaceSize(int maxM) { void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tactic, char* workspace_ptr_char, void const* expert_weights, - cudaStream_t const& stream) { + bool enable_pdl, cudaStream_t const& stream) { int64_t expanded_num_tokens = original_num_tokens * mK; int64_t num_experts_per_node = mNumExpertsPerNode; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h index f7ea83cdb0..80c024bdb7 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h @@ -24,7 +24,8 @@ namespace tensorrt_llm { namespace kernels { namespace cutlass_kernels { -template +template struct should_filter_tma_warp_specialized_gemm_problem_shape { #ifdef FAST_BUILD using SupportedCtaShape = @@ -32,15 +33,16 @@ struct should_filter_tma_warp_specialized_gemm_problem_shape { using SupportedCgaShape = cute::Shape; constexpr static bool value = !cute::is_same_v || - !cute::is_same_v; + !cute::is_same_v || DYNAMIC_CGA; #else constexpr static bool value = false; #endif }; -template +template constexpr static bool should_filter_tma_warp_specialized_gemm_problem_shape_v = should_filter_tma_warp_specialized_gemm_problem_shape::value; + DYNAMIC_CGA, ActivationType>::value; std::vector get_candidate_configs( int sm, int const max_split_k, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 916e8ab78f..aa57faca11 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -216,6 +216,9 @@ struct TmaWarpSpecializedGroupedGemmInput { uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; + // Whether to enable PDL (Programmatic Dependent Launch). 
+ bool enable_pdl{}; + static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 2ad23d7f7d..1e92fdbeba 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "cutlass/gemm/gemm.h" #include "moe_gemm_kernels.h" #include "tensorrt_llm/common/assert.h" @@ -219,6 +221,8 @@ struct MOEParallelismConfig { } }; +enum class MoeGemmId : int { Undefined = 0, GEMM_1, GEMM_2 }; + struct QuantParams { // Int weight only quantization params struct { @@ -426,14 +430,15 @@ class CutlassMoeFCRunnerInterface { bool use_awq) = 0; virtual void setTactic(std::optional gemm1_config, std::optional gemm2_config) = 0; - virtual std::vector getTactics() = 0; + virtual std::vector getTactics(MoeGemmId gemm_id) = 0; virtual void runMoe(void const* input_activations, void const* input_sf, - int const* token_selected_experts, float const* token_final_scales, - void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationParams fc1_activation_type, void const* fc2_expert_weights, - void const* fc2_expert_biases, QuantParams quant_params, - int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, + bool const swizzled_input_sf, int const* token_selected_experts, + float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, + void const* fc2_expert_weights, void const* fc2_expert_biases, + QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, @@ -459,26 +464,24 @@ class CutlassMoeFCRunnerInterface { int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl) = 0; - virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output, - int64_t const* const expert_first_token_offset, - TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template, - void const* const fc2_expert_weights, void const* const fc2_expert_biases, - void const* const fc2_int_scales, float const* const fc2_fp8_dequant, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, - QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, - int const* const unpermuted_row_to_permuted_row, - int const* permuted_row_to_unpermuted_row, - int const* const token_selected_experts, - int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, - int64_t const inter_size, int const num_experts_per_node, - int64_t const experts_per_token, float const** alpha_scale_ptr_array, - bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, - cudaStream_t stream, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, - bool min_latency_mode, int* num_active_experts_per, - int* 
active_expert_global_ids, bool enable_pdl) = 0; + virtual void gemm2( + void const* const input, void* const gemm_output, void* const final_output, + int64_t const* const expert_first_token_offset, + TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template, + void const* const fc2_expert_weights, void const* const fc2_expert_biases, + void const* const fc2_int_scales, float const* const fc2_fp8_dequant, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, + QuantParams quant_params, float const* const token_topk_unpermuted_scales, + float const* const token_topk_permuted_scales, + int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, + int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, + int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, + int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, + void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, + cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, + int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl) = 0; virtual std::pair computeStridesTmaWarpSpecializedDispatch( @@ -490,7 +493,8 @@ class CutlassMoeFCRunnerInterface { float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, bool enable_pdl, + void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream) = 0; virtual std::pair @@ -509,13 +513,13 @@ class CutlassMoeFCRunnerInterface { virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; - bool use_deterministic_hopper_reduce_ = false; + bool use_fused_finalize_ = true; }; // Assumes inputs activations are row major. Weights need to be preprocessed by // th_op/weight_quantize.cc . Nested in a class to avoid multiple calls to cudaGetDeviceProperties // as this call can be expensive. Avoid making several duplicates of this class. 
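// Editor's sketch (not part of the patch): the class comment above exists because
// cudaGetDeviceProperties is comparatively expensive, so the runner queries it once and caches
// values such as the SM version and multiprocessor count. A minimal illustration of that
// caching, with illustrative names rather than the runner's actual members:
#include <cuda_runtime_api.h>

struct CachedDeviceInfo {
  int sm_version;             // e.g. 90, 100, 120
  int multi_processor_count;  // number of SMs on the device
};

inline CachedDeviceInfo const& getDeviceInfoOnce() {
  static CachedDeviceInfo const cached = [] {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, device);
    return CachedDeviceInfo{props.major * 10 + props.minor, props.multiProcessorCount};
  }();
  return cached;
}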
-template ; -#if defined(ENABLE_FP4) #if defined(ENABLE_BF16) static constexpr bool use_wfp4a16 = std::is_same_v && (std::is_same_v || std::is_same_v); @@ -535,16 +538,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; #endif -#else - static constexpr bool use_wfp4a16 = false; -#endif - #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) && !std::is_same_v; static constexpr bool use_w4afp8 = std::is_same_v && std::is_same_v; + static constexpr bool use_fp8_input = std::is_same_v; static_assert(!std::is_same_v, "Current logic requires backbone type to be >=16-bits"); static_assert(!std::is_same_v, @@ -601,25 +601,26 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { gemm2_config_ = std::move(gemm2_config); } - std::vector getTactics() override { - return moe_gemm_runner_.getConfigs(); + std::vector getTactics(MoeGemmId gemm_id) override { + return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); } - static std::vector getTactics(int sm) { + static std::vector getTactics(int sm, MoeGemmId gemm_id) { using RunnerType = decltype(moe_gemm_runner_); - return RunnerType::getConfigs(sm); + return RunnerType::getConfigs(sm, + gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm)); } - void runMoe(void const* input_activations, void const* input_sf, + void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts, - int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, bool use_lora, LoraParams& lora_params, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, + int64_t const hidden_size, int64_t const unpadded_hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, + char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, + LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) override; @@ -663,11 +664,12 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, - int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, - float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, - cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl); + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, + int64_t const 
experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, + void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, + bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, + bool enable_pdl); // Overrides to allow us to forward on to the internal functions with the pointers using the // correct type @@ -710,7 +712,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream, @@ -727,10 +730,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static_cast(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params, token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, - num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size, inter_size, - num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, stream, - parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per, - active_expert_global_ids, enable_pdl); + num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size, unpadded_hidden_size, + inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, + fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode, + num_active_experts_per, active_expert_global_ids, enable_pdl); } virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override { @@ -747,7 +750,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, bool enable_pdl, + void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, + float const* router_scales, int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized( expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, @@ -758,7 +762,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast(bias1), reinterpret_cast(bias2), reinterpret_cast(gemm1_output), - reinterpret_cast(gemm2_output), enable_pdl, stream); + reinterpret_cast(gemm2_output), router_scales, + permuted_row_to_unpermuted_row, enable_pdl, stream); } std::pair @@ -789,8 +794,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { std::pair setupTmaWarpSpecializedInputs(int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, - int64_t inter_size, int64_t num_experts_per_node, - void const* 
input_activations_void, + int64_t unpadded_hidden_size, int64_t inter_size, + int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -811,7 +816,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, bool enable_pdl, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, float const* router_scales, + int const* permuted_row_to_unpermuted_row, bool enable_pdl, cudaStream_t stream); static std::pair computeStridesTmaWarpSpecializedLowLatency( TmaWarpSpecializedGroupedGemmInput layout_info1, @@ -844,8 +850,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { } bool mayHaveFinalizeFused() const { - return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 && - !use_deterministic_hopper_reduce_ && !use_w4_groupwise; + return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && + use_fused_finalize_ && !use_w4_groupwise; + } + + static bool mayHaveFinalizeFused(int sm) { + using RunnerType = decltype(moe_gemm_runner_); + return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise; } // TODO: This should eventually take the quant params to give more flexibility @@ -891,7 +902,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int const* const unpermuted_row_to_permuted_row, int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, bool enable_pdl, cudaStream_t stream); @@ -951,14 +963,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { struct GemmProfilerBackend { public: using Config = cutlass_extensions::CutlassGemmConfig; - enum class GemmToProfile { Undefined = 0, GEMM_1, GEMM_2 }; + using GemmToProfile = MoeGemmId; void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype, nvinfer1::DataType wtype, nvinfer1::DataType otype, - int num_experts, int k, int64_t hidden_size, int64_t inter_size, int64_t group_size, - ActivationType activation_type, bool bias, bool use_lora, bool min_latency_mode, - bool need_weights, MOEParallelismConfig parallelism_config, - bool const enable_alltoall) { + int num_experts, int k, int64_t hidden_size, int64_t unpadded_hidden_size, + int64_t inter_size, int64_t group_size, ActivationType activation_type, bool bias, + bool use_lora, bool min_latency_mode, bool need_weights, + MOEParallelismConfig parallelism_config, bool const enable_alltoall) { mInterface = &runner; mGemmToProfile = gemm_to_profile; mDType = dtype; @@ -968,6 +980,7 @@ struct GemmProfilerBackend { mNumExpertsPerNode = num_experts / parallelism_config.ep_size; mK = k; 
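// Editor's sketch (not part of the patch): mNumExpertsPerNode above is the expert-parallel share
// of the experts (num_experts / ep_size). A small, illustrative helper showing the bookkeeping
// this implies (assumes num_experts divides evenly by ep_size and a contiguous expert layout):
struct LocalExpertRange {
  int first_expert;           // first expert owned by this EP rank (illustrative)
  int num_experts_per_node;   // experts owned by this EP rank
};

inline LocalExpertRange localExpertRange(int num_experts, int ep_size, int ep_rank) {
  int const per_node = num_experts / ep_size;
  return LocalExpertRange{ep_rank * per_node, per_node};
}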
mExpertHiddenSize = hidden_size; + mExpertUnpaddedHiddenSize = unpadded_hidden_size; mExpertInterSize = inter_size; // Already divided by tp_size mGroupSize = group_size; mActivationType = activation_type; @@ -1001,12 +1014,12 @@ struct GemmProfilerBackend { CutlassMoeFCRunnerInterface* mInterface; GemmToProfile mGemmToProfile = GemmToProfile::Undefined; - std::vector mAllTacticsSaved; int mSM{}; int64_t mNumExperts{}; int64_t mNumExpertsPerNode{}; int64_t mK{}; int64_t mExpertHiddenSize{}; + int64_t mExpertUnpaddedHiddenSize{}; int64_t mExpertInterSize{}; int64_t mGroupSize{}; ActivationType mActivationType{}; @@ -1022,7 +1035,11 @@ struct GemmProfilerBackend { // This will be a unique value for every iteration of warmup and actual bench constexpr static int64_t NUM_ROUTING_SAMPLES = 16; - std::array mTmaInputCache; + constexpr static int64_t NUM_FUSION_TYPES = 2; + constexpr static int64_t NUM_SWAP_AB_TYPES = 2; + constexpr static int64_t NUM_WORKSPACES = NUM_FUSION_TYPES * NUM_SWAP_AB_TYPES; + TmaWarpSpecializedGroupedGemmInput mTmaInputCache[NUM_FUSION_TYPES][NUM_SWAP_AB_TYPES] + [NUM_ROUTING_SAMPLES]; QuantParams mQuantParams; bool mBias{}; @@ -1036,6 +1053,7 @@ struct GemmProfilerBackend { void prepareRouting(int num_tokens, char* workspace, bool enable_pdl, cudaStream_t stream); void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream); void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, bool enable_pdl, cudaStream_t stream); }; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index e61aad03cb..db5788bfdd 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -697,7 +697,7 @@ using namespace cutlass::epilogue; TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ std::string(cutlass::cutlassGetStatusString(init_status))); \ - auto run_status = gemm.run(stream, nullptr, tensorrt_llm::common::getEnvEnablePDL()); \ + auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ "Failed to run cutlass TMA WS grouped gemm. 
Error: " + \ std::string(cutlass::cutlassGetStatusString(run_status))); \ From 76a922019ef4a77f2b382da2764271338a12bb44 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Mon, 13 Oct 2025 21:42:35 +0000 Subject: [PATCH 08/17] >gather_tensor.hpp --- .../cutlass_extensions/util/gather_tensor.hpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp index 4ba4fc9f20..5a3b5f2302 100644 --- a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -34,7 +34,7 @@ #include "cute/tensor.hpp" #include "cute/util/print.hpp" -using namespace cute; +namespace cutlass::util { /// Function object that applies an index to its argument template @@ -48,8 +48,8 @@ struct IndexedGather { CUTE_HOST_DEVICE friend void print(IndexedGather const& s) { cute::print("Indexed{"); - print(s.indices_); - print("}"); + cute::print(s.indices_); + cute::print("}"); } Iter indices_; @@ -73,23 +73,23 @@ struct CustomStride { CUTE_HOST_DEVICE friend void print(CustomStride const& s) { cute::print("Custom{"); - print(s.func_); + cute::print(s.func_); cute::print(","); - print(s.stride_); + cute::print(s.stride_); cute::print("}"); } template CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) { - return CustomStride(s.func_, - safe_div(s.stride_, div)); + return CustomStride( + s.func_, cute::safe_div(s.stride_, div)); } // Circumvent the requirement on make_layout that shape and stride are integral template CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) { - return Layout(shape, stride); + return cute::Layout(shape, stride); } Func func_; @@ -98,6 +98,7 @@ struct CustomStride { template CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) { + using namespace cute; // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather // stride auto idx = find_if(stride, [](auto x) { @@ -112,11 +113,13 @@ CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& template CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) { + using namespace cute; Layout matrix_layout = make_identity_layout(shape); auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); } +} // namespace cutlass::util namespace cute { From 96c0ed494241eab4f723fbcb0eb9cde16ce3e88e Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Mon, 13 Oct 2025 22:55:44 +0000 Subject: [PATCH 09/17] fix compilation errors --- ...shinfer_cutlass_fused_moe_sm100_binding.cu | 109 +++--- .../include/cutlass_extensions/gemm_configs.h | 358 ++++++++++-------- 2 files changed, 256 insertions(+), 211 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index 0bd58e8047..f8e6a5d028 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ 
b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -72,6 +72,8 @@ class DtypeUtils { default: TVM_FFI_ICHECK(false) << "unsupported data type"; } + + return nvinfer1::DataType::kFLOAT; // supress compiler warning } private: @@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TVM_FFI_ICHECK(false) << "Invalid output type " << DLDataTypeToString(output_type) << " specified for " << DLDataTypeToString(mActivationDtype); } + + return nullptr; // supress compiler warning }; FusedMoeRunner(DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype, @@ -219,7 +223,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { } mProfiler = std::make_shared(); - mAllProfiles = mKernelRunner->getTactics(); + // Get tactics for both GEMM1 and GEMM2, combine them + auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1); + auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2); + mAllProfiles = gemm1_tactics; + mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end()); TVM_FFI_ICHECK(!mAllProfiles.empty()) << "No valid tactics available for fused moe op with the requested input combination " "Activation: " @@ -361,25 +369,29 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; + // HACK Define default values for parameters we don't have good values for + bool const swizzled_input_sf = false; // Assume input_sf is not swizzled by default + int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default + bool const use_lora = false; // No lora support yet #ifdef USING_OSS_CUTLASS_MOE_GEMM - mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, - reinterpret_cast(token_selected_experts->data), - token_final_scales.has_value() - ? reinterpret_cast(token_final_scales.value()->data) - : nullptr, - fc1_expert_weights->data, - fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, - activation_params, fc2_expert_weights->data, - fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, - static_cast(experts_per_token), - static_cast(workspace_info.workspace->data), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, - enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling, - min_latency_mode, min_latency_params, enable_pdl, stream); + mKernelRunner->runMoe( + input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf, + reinterpret_cast(token_selected_experts->data), + token_final_scales.has_value() + ? reinterpret_cast(token_final_scales.value()->data) + : nullptr, + fc1_expert_weights->data, + fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, + activation_params, fc2_expert_weights->data, + fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params, + num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total, + static_cast(experts_per_token), static_cast(workspace_info.workspace->data), + output->data, static_cast(workspace_info.src_to_dest_map), parallelism_config, + enable_alltoall, use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, + min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( - input->data, input_sf.has_value() ? 
input_sf.value()->data : nullptr, + input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf, reinterpret_cast(token_selected_experts->data), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value()->data) @@ -388,10 +400,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, activation_params, fc2_expert_weights->data, fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params, - num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), - static_cast(workspace_info.workspace), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, - mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); + num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total, + static_cast(experts_per_token), static_cast(workspace_info.workspace), + output->data, static_cast(workspace_info.src_to_dest_map), parallelism_config, false, + use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #endif } @@ -530,25 +543,29 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; + // HACK Define default values for parameters we don't have good values for + bool const swizzled_input_sf_ml = false; // Assume input_sf is not swizzled by default + int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default + bool const use_lora_ml = false; // No lora support yet #ifdef USING_OSS_CUTLASS_MOE_GEMM - mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, - reinterpret_cast(token_selected_experts->data), - token_final_scales.has_value() - ? reinterpret_cast(token_final_scales.value()->data) - : nullptr, - fc1_expert_weights->data, - fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, - activation_params, fc2_expert_weights->data, - fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, - static_cast(experts_per_token), - static_cast(workspace_info.workspace->data), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, - enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling, - min_latency_mode, min_latency_params, enable_pdl, stream); + mKernelRunner->runMoe( + input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf_ml, + reinterpret_cast(token_selected_experts->data), + token_final_scales.has_value() + ? reinterpret_cast(token_final_scales.value()->data) + : nullptr, + fc1_expert_weights->data, + fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, + activation_params, fc2_expert_weights->data, + fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params, + num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total, + static_cast(experts_per_token), static_cast(workspace_info.workspace->data), + output->data, static_cast(workspace_info.src_to_dest_map), parallelism_config, + enable_alltoall, use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, + min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( - input->data, input_sf.has_value() ? 
input_sf.value()->data : nullptr, + input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf_ml, reinterpret_cast(token_selected_experts->data), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value()->data) @@ -557,10 +574,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, activation_params, fc2_expert_weights->data, fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params, - num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), - static_cast(workspace_info.workspace), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, - mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); + num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total, + static_cast(experts_per_token), static_cast(workspace_info.workspace), + output->data, static_cast(workspace_info.src_to_dest_map), parallelism_config, false, + use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #endif } @@ -621,19 +639,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { auto activation_dtype = (mUseW4GroupScaling && !isWFP4A16Quant()) ? dl_float8_e4m3fn : mActivationDtype; activation_dtype = isNvfp4Quant() ? dl_int64 : activation_dtype; + int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default #ifdef USING_OSS_CUTLASS_MOE_GEMM mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype), DtypeUtils::dataType(mOutputDtype), num_experts, static_cast(top_k), - hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, - USE_LORA, min_latency_mode, + hidden_size, unpadded_hidden_size_profiler, inter_size, group_size, + ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, /*need_weights*/ false, parallelism_config, enable_alltoall); #else mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype), DtypeUtils::dataType(mOutputDtype), num_experts, static_cast(top_k), - hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, - USE_LORA, min_latency_mode, + hidden_size, unpadded_hidden_size_profiler, inter_size, group_size, + ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, /*need_weights*/ false, parallelism_config); #endif diff --git a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index 6c5a823e8e..b2301c1a82 100644 --- a/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,8 +19,12 @@ #include #include #include +#include +#include #include "cute/tensor.hpp" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/tllmException.h" namespace tensorrt_llm { namespace cutlass_extensions { @@ -30,10 +34,10 @@ namespace cutlass_extensions { // in the kernel layout details when doing weight only quantization. enum class CutlassTileConfig { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + ChooseWithHeuristic = 1, // SiMT config CtaShape128x128x8_WarpShape64x64x8, @@ -77,77 +81,96 @@ enum class SplitKStyle { // SPLIT_K_PARALLEL // Not supported yet }; -enum class CutlassTileConfigSM90 { +constexpr static int shape_tuple_to_enum(int m, int n, int k) { + assert(m >= 0 && n >= 0 && k >= 0); + assert(m < 1000 && n < 1000 && k < 1000); + return m * 1000000 + n * 1000 + k; +} + +template +constexpr static std::tuple enum_to_shape_tuple(TEnum shape_id_enum) { + static_assert(std::is_enum_v && std::is_same_v, int>, + "TEnum must be an enum with underlying type int"); + auto shape_id = static_cast(shape_id_enum); + assert(shape_id >= 0); + assert(shape_id < (int)1e9); + return std::make_tuple(shape_id / 1000000, (shape_id % 1000000) / 1000, shape_id % 1000); +} + +enum class CutlassTileConfigSM90 : int { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + ChooseWithHeuristic = 1, // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, + CtaShape64x16x128B = shape_tuple_to_enum(64, 16, 128), + CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128), + CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128), + CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128), + CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128), // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, + CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128), + CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128), + CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128), + CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128), + CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128), // CTA configs for M=256 - CtaShape256x128x128B, - CtaShape256x256x128B, + CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128), + CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128), }; -enum class CutlassTileConfigSM100 { +enum class CutlassTileConfigSM100 : int { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + ChooseWithHeuristic = 1, /* * Grouped GEMM */ // M=64 - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, + CtaShape64x32x128B = shape_tuple_to_enum(64, 32, 128), + CtaShape64x64x128B = shape_tuple_to_enum(64, 64, 128), + CtaShape64x128x128B = shape_tuple_to_enum(64, 128, 128), + CtaShape64x256x128B = shape_tuple_to_enum(64, 256, 128), // M=128 - CtaShape128x8x256B, - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - CtaShape128x128x256B, - CtaShape128x256x256B, + CtaShape128x8x256B = shape_tuple_to_enum(128, 8, 256), + 
CtaShape128x16x128B = shape_tuple_to_enum(128, 16, 128), + CtaShape128x32x128B = shape_tuple_to_enum(128, 32, 128), + CtaShape128x64x128B = shape_tuple_to_enum(128, 64, 128), + CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128), + CtaShape128x256x128B = shape_tuple_to_enum(128, 256, 128), + CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256), + CtaShape128x256x256B = shape_tuple_to_enum(128, 256, 256), // M=256 - CtaShape256x64x128B, - CtaShape256x128x128B, - CtaShape256x256x128B, + CtaShape256x64x128B = shape_tuple_to_enum(256, 64, 128), + CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128), + CtaShape256x256x128B = shape_tuple_to_enum(256, 256, 128), }; -enum class CutlassTileConfigSM120 { +// An alias to make the SHAPE_CASE macro work +using CutlassTileConfigSM103 = CutlassTileConfigSM100; + +enum class CutlassTileConfigSM120 : int { // Signals that we should run heuristics do choose a config - Undefined, + Undefined = 0, // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - CtaShape128x128x128B, - CtaShape128x128x64B, - CtaShape256x128x64B, - CtaShape128x256x64B, - CtaShape128x128x256B, - CtaShape256x128x128B, + ChooseWithHeuristic = 1, + + CtaShape128x128x128B = shape_tuple_to_enum(128, 128, 128), + CtaShape128x128x64B = shape_tuple_to_enum(128, 128, 64), + CtaShape256x128x64B = shape_tuple_to_enum(256, 128, 64), + CtaShape128x256x64B = shape_tuple_to_enum(128, 256, 64), + CtaShape128x128x256B = shape_tuple_to_enum(128, 128, 256), + CtaShape256x128x128B = shape_tuple_to_enum(256, 128, 128), }; enum class MainloopScheduleType { @@ -175,115 +198,73 @@ enum class EpilogueScheduleType { AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop // schedule for Hopper. For architectures older than hopper, the epilogue is always // performed by the same thread block as the main loop. 
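For clarity, a self-contained sketch of the shape encoding these enums now rely on: an MxNxK shape is packed into a single int as m*1'000'000 + n*1'000 + k, with 0 and 1 reserved for Undefined and ChooseWithHeuristic, so the tile and cluster enumerators across SM90/SM100/SM120 share one decodable numbering. The snippet below only mirrors the two helpers defined above and is not part of the patch.

  #include <tuple>

  constexpr int shape_tuple_to_enum(int m, int n, int k) { return m * 1000000 + n * 1000 + k; }

  constexpr std::tuple<int, int, int> enum_to_shape_tuple(int shape_id) {
    return {shape_id / 1000000, (shape_id % 1000000) / 1000, shape_id % 1000};
  }

  // CtaShape128x64x128B and TileShape_128x64x128 both encode to the same underlying value:
  static_assert(shape_tuple_to_enum(128, 64, 128) == 128064128, "128x64x128 packs to 128064128");
  static_assert(std::get<1>(enum_to_shape_tuple(128064128)) == 64, "N decodes back to 64");
  // get_tile_shape_name() can therefore print "128x64x128" straight from the decoded tuple,
  // and get_tile_shape() can build cute::Shape<_128, _64, _128> from the same three ints.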
+ NO_SMEM, + TMA }; -enum class TileShape { - TileShape_64x16x128, - TileShape_64x32x128, - TileShape_64x64x128, - TileShape_64x128x128, - TileShape_64x256x128, - TileShape_64x512x128, - TileShape_128x16x128, - TileShape_128x32x128, - TileShape_128x64x128, - TileShape_128x128x128, - TileShape_128x256x128, - TileShape_256x128x128, - TileShape_256x256x128 +enum class TileShape : int { + Undefined = 0, + TileShape_64x16x128 = shape_tuple_to_enum(64, 16, 128), + TileShape_64x32x128 = shape_tuple_to_enum(64, 32, 128), + TileShape_64x64x128 = shape_tuple_to_enum(64, 64, 128), + TileShape_64x128x128 = shape_tuple_to_enum(64, 128, 128), + TileShape_64x256x128 = shape_tuple_to_enum(64, 256, 128), + TileShape_64x512x128 = shape_tuple_to_enum(64, 512, 128), + TileShape_128x16x128 = shape_tuple_to_enum(128, 16, 128), + TileShape_128x32x128 = shape_tuple_to_enum(128, 32, 128), + TileShape_128x64x128 = shape_tuple_to_enum(128, 64, 128), + TileShape_128x128x128 = shape_tuple_to_enum(128, 128, 128), + TileShape_128x256x128 = shape_tuple_to_enum(128, 256, 128), + TileShape_256x128x128 = shape_tuple_to_enum(256, 128, 128), + TileShape_256x256x128 = shape_tuple_to_enum(256, 256, 128) }; template constexpr auto get_tile_shape() { using namespace cute; - if constexpr (Shape_MNK == TileShape::TileShape_64x16x128) { - return cute::Shape<_64, _16, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x32x128) { - return cute::Shape<_64, _32, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x64x128) { - return cute::Shape<_64, _64, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x128x128) { - return cute::Shape<_64, _128, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x256x128) { - return cute::Shape<_64, _256, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_64x512x128) { - return cute::Shape<_64, _512, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x16x128) { - return cute::Shape<_128, _16, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x32x128) { - return cute::Shape<_128, _32, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x64x128) { - return cute::Shape<_128, _64, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x128x128) { - return cute::Shape<_128, _128, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) { - return cute::Shape<_128, _256, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_256x128x128) { - return cute::Shape<_256, _128, _128>{}; - } else if constexpr (Shape_MNK == TileShape::TileShape_256x256x128) { - return cute::Shape<_256, _256, _128>{}; - } + static_assert(Shape_MNK != TileShape::Undefined, "TileShape is undefined"); + + constexpr auto shape_tuple = enum_to_shape_tuple(Shape_MNK); + return cute::Shape(shape_tuple)>, cute::Int(shape_tuple)>, + cute::Int(shape_tuple)>>{}; } -static auto get_tile_shape_name(TileShape Shape_MNK) { - if (Shape_MNK == TileShape::TileShape_64x16x128) { - return "64x16x128"; - } else if (Shape_MNK == TileShape::TileShape_64x32x128) { - return "64x32x128"; - } else if (Shape_MNK == TileShape::TileShape_64x64x128) { - return "64x64x128"; - } else if (Shape_MNK == TileShape::TileShape_64x128x128) { - return "64x128x128"; - } else if (Shape_MNK == TileShape::TileShape_64x256x128) { - return "64x256x128"; - } else if (Shape_MNK == TileShape::TileShape_64x512x128) { - return "64x512x128"; - } else if (Shape_MNK == TileShape::TileShape_128x16x128) { - 
return "128x16x128"; - } else if (Shape_MNK == TileShape::TileShape_128x32x128) { - return "128x32x128"; - } else if (Shape_MNK == TileShape::TileShape_128x64x128) { - return "128x64x128"; - } else if (Shape_MNK == TileShape::TileShape_128x128x128) { - return "128x128x128"; - } else if (Shape_MNK == TileShape::TileShape_128x256x128) { - return "128x256x128"; - } else if (Shape_MNK == TileShape::TileShape_256x128x128) { - return "256x128x128"; - } else if (Shape_MNK == TileShape::TileShape_256x256x128) { - return "256x256x128"; +template +static std::string get_tile_shape_name(TEnum Shape_MNK) { + static_assert(std::is_enum_v && std::is_same_v, int>, + "TEnum must be an enum with underlying type int"); + if ((int)Shape_MNK == 0) { + return "undefined"; + } else if ((int)Shape_MNK == 1) { + return "heuristic"; + } else { + auto [m, n, k] = enum_to_shape_tuple(Shape_MNK); + return std::to_string(m) + "x" + std::to_string(n) + "x" + std::to_string(k); } - return "Unknown shape"; } -enum class ClusterShape { - ClusterShape_1x1x1, - ClusterShape_2x1x1, - ClusterShape_1x2x1, - ClusterShape_2x2x1, - ClusterShape_1x4x1, - ClusterShape_4x2x1, - ClusterShape_2x4x1, - ClusterShape_4x4x1, - ClusterShape_1x8x1, - ClusterShape_8x1x1 +enum class ClusterShape : int { + Undefined = 0, + ClusterShape_1x1x1 = shape_tuple_to_enum(1, 1, 1), + ClusterShape_2x1x1 = shape_tuple_to_enum(2, 1, 1), + ClusterShape_1x2x1 = shape_tuple_to_enum(1, 2, 1), + ClusterShape_2x2x1 = shape_tuple_to_enum(2, 2, 1), + ClusterShape_1x4x1 = shape_tuple_to_enum(1, 4, 1), + ClusterShape_4x1x1 = shape_tuple_to_enum(4, 1, 1), + ClusterShape_4x2x1 = shape_tuple_to_enum(4, 2, 1), + ClusterShape_2x4x1 = shape_tuple_to_enum(2, 4, 1), + ClusterShape_4x4x1 = shape_tuple_to_enum(4, 4, 1), + ClusterShape_1x8x1 = shape_tuple_to_enum(1, 8, 1), + ClusterShape_8x1x1 = shape_tuple_to_enum(8, 1, 1) }; -static auto get_cluster_shape_name(ClusterShape Shape_MNK) { - if (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { - return "1x1x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { - return "2x1x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { - return "1x2x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { - return "2x2x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { - return "1x8x1"; - } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { - return "8x1x1"; +static std::string get_cluster_shape_name(ClusterShape Shape_MNK) { + if (Shape_MNK == ClusterShape::Undefined) { + return "undefined"; + } else { + auto [m, n, k] = enum_to_shape_tuple(Shape_MNK); + return std::to_string(m) + "x" + std::to_string(n) + "x" + std::to_string(k); } - return "Unknown shape"; } template @@ -297,10 +278,22 @@ constexpr auto get_cluster_shape() { return cute::Shape<_1, _2, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { return cute::Shape<_2, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x1x1) { + return cute::Shape<_4, _1, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { return cute::Shape<_1, _8, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { return cute::Shape<_8, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x4x1) { + return cute::Shape<_1, _4, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x2x1) { + return cute::Shape<_4, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x4x1) { + return cute::Shape<_2, 
_4, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x4x1) { + return cute::Shape<_4, _4, _1>{}; + } else { + return cute::Shape<_0, _0, _0>{}; } } @@ -314,7 +307,8 @@ struct CutlassGemmConfig { BLACKWELL = 1u << 4, GROUPED_GEMM = 1u << 5, FP8_ONLY = 1u << 6, - FP4_ONLY = 1u << 7 + FP4_ONLY = 1u << 7, + FP8FP4_MIXED = 1u << 8 }; CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic; @@ -329,10 +323,17 @@ struct CutlassGemmConfig { MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + ClusterShape dynamic_cluster_shape = ClusterShape::Undefined; + ClusterShape fallback_cluster_shape = ClusterShape::Undefined; bool enableCudaKernel = false; int sm_version = 80; // Use 80 as a catch all for <90 bool is_tma_warp_specialized = false; + enum class EpilogueFusionType : int { NONE, FINALIZE }; + + EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE; + bool swap_ab = false; + CutlassGemmConfig() = default; CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, @@ -352,15 +353,24 @@ struct CutlassGemmConfig { sm_version(90), is_tma_warp_specialized(true) {} + // If dynamic_cluster_shape is provided, dynamic CGA will be enabled and cluster_shape will be + // interpreted as whether to use 1 or 2 SM mode, otherwise static cluster shape is used. CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule, - ClusterShape cluster_shape) + ClusterShape cluster_shape, + ClusterShape dynamic_cluster_shape = ClusterShape::Undefined, + ClusterShape fallback_cluster_shape = ClusterShape::Undefined, + int sm_version = 100) : tile_config_sm100(tile_config_sm100), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), - sm_version(100), - is_tma_warp_specialized(true) {} + dynamic_cluster_shape(dynamic_cluster_shape), + fallback_cluster_shape(fallback_cluster_shape), + sm_version(sm_version), + is_tma_warp_specialized(true) { + TLLM_CHECK_WITH_INFO(sm_version >= 100 && sm_version < 120, "Expected SM 10x version"); + } CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule, @@ -373,26 +383,38 @@ struct CutlassGemmConfig { is_tma_warp_specialized(true) {} int getTileConfigAsInt() const { - if (sm_version == 120) return (int)tile_config_sm120; - if (sm_version == 110) return (int)tile_config_sm100; - if (sm_version >= 100) return (int)tile_config_sm100; + if (sm_version == 120 || sm_version == 121) return (int)tile_config_sm120; + if (sm_version >= 100 && sm_version < 120) return (int)tile_config_sm100; if (sm_version == 90) return (int)tile_config_sm90; if (sm_version < 90) return (int)tile_config_sm80; assert(false && "Invalid SM version"); return -1; } + std::string getTileConfigAsName() const { + if (sm_version == 120 || sm_version == 121) return get_tile_shape_name(tile_config_sm120); + if (sm_version >= 100 && sm_version < 120) return get_tile_shape_name(tile_config_sm100); + if (sm_version == 90) return get_tile_shape_name(tile_config_sm90); + if (sm_version < 90) return std::to_string((int)tile_config_sm80); + assert(false && "Invalid SM version"); + return "invalid"; + } + std::string toString() const { std::stringstream tactic; tactic << "Cutlass GEMM 
Tactic"; if (is_tma_warp_specialized) { assert(sm_version >= 90 && "Invalid cutlass GEMM config"); tactic << "\n\tstyle=TMA Warp Specialized" - << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() - << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsName() + << "\n\tcluster shape ID: " << get_cluster_shape_name(cluster_shape) + << "\n\tdynamic cluster shape ID: " << get_cluster_shape_name(dynamic_cluster_shape) + << "\n\tfallback cluster shape ID: " << get_cluster_shape_name(fallback_cluster_shape) << "\n\tmainloop sched: " << (int)mainloop_schedule << "\n\tepi sched: " << (int)epilogue_schedule - << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false") + << "\n\tepilogue fusion type: " << (int)epilogue_fusion_type + << "\n\tswap_ab: " << (swap_ab ? "true" : "false"); } else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { assert(sm_version < 90 && "Invalid cutlass GEMM config"); @@ -412,22 +434,26 @@ struct CutlassGemmConfig { inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) { // clang-format off - if (config.is_tma_warp_specialized) - { - out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() - << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) - << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) - << ", cluster_shape_enum: " << int(config.cluster_shape) - << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); - } - else - { - out << "tile_config_enum: " << config.getTileConfigAsInt() - << ", split_k_style_enum: " << int(config.split_k_style) - << ", split_k_factor: " << config.split_k_factor - << ", stages: " << config.stages - << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); - } + if (config.is_tma_warp_specialized) + { + out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape) + << ", dynamic_cluster_shape_enum: " << int(config.dynamic_cluster_shape) + << ", fallback_cluster_shape_enum: " << int(config.fallback_cluster_shape) + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false") + << ", epilogue_fusion_type: " << int(config.epilogue_fusion_type) + << ", swap_ab: " << (config.swap_ab ? "true" : "false"); + } + else + { + out << "tile_config_enum: " << config.getTileConfigAsInt() + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? 
"true" : "false"); + } // clang-format on return out; } From af4036dbe3987223dde3cbb96bc473d6e8da7ff1 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 27 Aug 2025 06:10:59 +0000 Subject: [PATCH 10/17] fix compilation error for sm120 --- .../launchers/fused_moe_gemm_launcher_sm80.inl | 17 ++++++++--------- tests/moe/test_trtllm_cutlass_fused_moe.py | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl index 52355f34c7..8ecc3fc18b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -34,10 +34,10 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy) { - constexpr auto activation_type = fused_moe_oss::EpilogueRouting(true); + constexpr auto activation_type = fused_moe::EpilogueRouting(true); using GemmType = - fused_moe_oss::Fused_Moe_Kernel_sm80; + fused_moe::Fused_Moe_Kernel_sm80; // make sure GPU has enough resources.. if (kernel_occupancy != nullptr) { @@ -51,7 +51,7 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe tensorrt_llm::common::check_cuda_error(cudaDeviceGetAttribute( &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); tensorrt_llm::common::check_cuda_error( - cudaFuncGetAttributes(&attr, fused_moe_oss::run_global)); + cudaFuncGetAttributes(&attr, fused_moe::run_global)); if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { // This should mean that // cudaFuncSetAttribute(cutlass::Kernel, @@ -64,12 +64,11 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe int max_active_blocks = -1; tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, fused_moe_oss::run_global, GemmType::kThreadCount, - smem_size)); + &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); *kernel_occupancy = max_active_blocks; return; } - int occupancy = std::min(2, fused_moe_oss::fused_gemm_maximum_active_blocks()); + int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); int const threadblock_count = multi_processor_count * occupancy; TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); @@ -83,7 +82,7 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe auto params = GemmType::to_underlying_arguments(args); if (GemmType::kSmemSize >= (48 << 10)) { cudaError_t result = - cudaFuncSetAttribute(fused_moe_oss::run_global, + cudaFuncSetAttribute(fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + @@ -91,7 +90,7 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe } dim3 grid(params.threadblock_count, 1, 1); dim3 block(GemmType::kThreadCount); - fused_moe_oss::run_global<<>>(params); + fused_moe::run_global<<>>(params); auto result = cudaGetLastError(); 
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int)(result)); diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index ecdb3453da..df2c484c35 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -1093,7 +1093,7 @@ def dequant_mxfp4_batches( ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] ) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] not in [10, 11], + torch.cuda.get_device_capability()[0] not in [10, 11, 12], reason="MXFP8xMXFP4 is only supported on SM100 and SM110", ) def test_moe_mxfp8_mxfp4( From f109a2b175f52f52c1e196c236f356ca2c428d4b Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 27 Aug 2025 09:09:41 +0000 Subject: [PATCH 11/17] Fix aot failures --- .../cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu | 2 +- .../cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu index 3e3db70369..c1d40e33ac 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu @@ -17,7 +17,7 @@ #include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { -#ifdef ENABLE_BF16 +#if defined(ENABLE_BF16) && defined(ENABLE_FP4) template class MoeGemmRunner<__nv_bfloat16, __nv_fp4_e2m1, __nv_bfloat16>; #endif } // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu index 91cd9413b8..ce4b57cc69 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu @@ -17,5 +17,7 @@ #include "moe_gemm_template_dispatch.h" namespace tensorrt_llm::kernels::cutlass_kernels { +#if defined(ENABLE_FP4) template class MoeGemmRunner; -} +#endif +} // namespace tensorrt_llm::kernels::cutlass_kernels From a49d1fdcfb35736b420c56c1208abc30347ff54c Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 17 Oct 2025 08:32:12 -0700 Subject: [PATCH 12/17] Add #if defined(ENABLE_FP4) guards --- .../cutlass_kernels/include/moe_gemm_kernels.h | 4 ++++ .../cutlass_kernels/include/moe_kernels.h | 4 ++++ .../moe_gemm/moe_gemm_template_dispatch.h | 10 ++++++++++ .../moe_gemm_template_dispatch_tma_ws.h | 12 ++++++++++++ ...gemm_template_dispatch_tma_ws_mixed_dtype.h | 10 ++++++++++ .../moe_gemm/moe_tma_warp_specialized_traits.h | 18 +++++++++++++++--- 6 files changed, 55 insertions(+), 3 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index aa57faca11..b77efbcac1 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -248,6 +248,7 @@ class MoeGemmRunner { public: MoeGemmRunner(); +#if defined(ENABLE_FP4) #if 
defined(ENABLE_BF16) static constexpr bool use_wfp4a16 = std::is_same_v && (std::is_same_v || std::is_same_v); @@ -255,6 +256,9 @@ class MoeGemmRunner { static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; #endif +#else + static constexpr bool use_wfp4a16 = false; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) && diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 1e92fdbeba..e278269b97 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -531,6 +531,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { using ScaleBiasType = BackBoneType; using Self = CutlassMoeFCRunner; +#if defined(ENABLE_FP4) #if defined(ENABLE_BF16) static constexpr bool use_wfp4a16 = std::is_same_v && (std::is_same_v || std::is_same_v); @@ -538,6 +539,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; #endif +#else + static constexpr bool use_wfp4a16 = false; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) && diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 6318992b73..6da2ffdeed 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -96,7 +96,9 @@ struct genericMoeGemmKernelLauncher { static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || +#if defined(ENABLE_FP4) cutlass::platform::is_same::value || +#endif cutlass::platform::is_same::value); static_assert(arch::kMinComputeCapability < 90, @@ -737,15 +739,21 @@ void MoeGemmRunner::dispatchToArch( "Hopper configuration provided for non-Hopper architecture"); if (sm_ >= 75 && sm_ < 80) { +#if defined(ENABLE_FP4) if constexpr (!std::is_same_v) { +#endif cutlass_kernels_oss::dispatchMoeGemmToCutlass( inputs, multi_processor_count_); +#if defined(ENABLE_FP4) } else { TLLM_THROW("FP4 data type is not supported on SM < 90"); } +#endif } else if (sm_ >= 80 && sm_ < 90) { +#if defined(ENABLE_FP4) if constexpr (!std::is_same_v) { +#endif if constexpr (use_fp8 || use_w4afp8) { #if defined(ENABLE_FP8) static_assert(!std::is_same_v && @@ -763,9 +771,11 @@ void MoeGemmRunner::dispatchToArch( cutlass::arch::Sm80, EpilogueTag>( inputs, multi_processor_count_); } +#if defined(ENABLE_FP4) } else { TLLM_THROW("FP4 data type is not supported on SM < 90"); } +#endif } else if (sm_ >= 90) { // For SM120+ pure FP8 MoE (not FP8 x FP4), redirect to SM89 (Ada) FP8 kernel implementations. 
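The guard pattern applied throughout these files, shown as a minimal sketch (the trait name and dispatch body are placeholders, not the real kernels): when ENABLE_FP4 is absent the FP4 type never has to be named, because the preprocessor substitutes a compile-time false and if constexpr discards the FP4 branch entirely.

  #include <type_traits>

  template <class WeightType>
  void dispatch_example() {
  #if defined(ENABLE_FP4)
    constexpr bool is_fp4_weight = std::is_same_v<WeightType, __nv_fp4_e2m1>;
  #else
    constexpr bool is_fp4_weight = false;  // FP4 types are not available in this build
  #endif
    if constexpr (is_fp4_weight) {
      // FP4 path: only instantiated when ENABLE_FP4 is defined.
    } else {
      // Non-FP4 path: always compiles and never mentions __nv_fp4_e2m1.
    }
  }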
if constexpr (use_fp8 && !use_wfp4afp8) { diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 4fd4daa8d1..108ba9b81c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -70,8 +70,12 @@ auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilog bool dynamic_cga, bool swap_ab) { auto select_swap_ab = [dynamic_cga, epilogue_schedule](auto swap_ab_t) { auto select_dynamic_cga = [epilogue_schedule](auto dynamic_cga_t) { +#if defined(ENABLE_FP4) constexpr bool is_block_scaled = std::is_same_v || std::is_same_v; +#else + constexpr bool is_block_scaled = false; +#endif if constexpr ((!is_block_scaled || Arch::kMinComputeCapability == 103) && FUSION != EpilogueFusion::FINALIZE) { auto func_map = std::array{ @@ -156,8 +160,12 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( } #endif else { +#if defined(ENABLE_FP4) constexpr static bool is_wfp4afp8 = std::is_same_v && std::is_same_v; +#else + constexpr static bool is_wfp4afp8 = false; +#endif if constexpr (is_wfp4afp8) { TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, @@ -220,8 +228,12 @@ constexpr bool are_tile_shapes_supported_sm100() { constexpr auto TileN = size<1>(CtaShape{}); if constexpr (Arch::kMinComputeCapability == 103) { +#if defined(ENABLE_FP4) return std::is_same_v && std::is_same_v && TileM == 128 && (TileN == 128 || TileN == 256); +#else + return false; +#endif } if constexpr (TileM != 64 && TileM != 128) { diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h index 8c5a5d45e7..eaaedf4258 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h @@ -161,10 +161,16 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually // perform the best for mixed type gemms. +#if defined(ENABLE_FP4) constexpr int Ntile = (std::is_same_v) ? 64 : 128; constexpr int Ktile = (std::is_same_v) ? 128 : 128 * PackedScalesNum / sizeof(T); TLLM_CHECK(sizeof(T) == (std::is_same_v) ? 2 : 1); +#else + constexpr int Ntile = 128; + constexpr int Ktile = 128 * PackedScalesNum / sizeof(T); + TLLM_CHECK(sizeof(T) == 1); +#endif using _Ntile = Int; using _Ktile = Int; @@ -246,7 +252,11 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( template size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_count_) { size_t count = 0; +#if defined(ENABLE_FP4) constexpr int Ktile = (std::is_same_v) ? 
256 : 512; +#else + constexpr int Ktile = 512; +#endif using _Ktile = Int; #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index 33a9fc8530..5a2d941f42 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -33,11 +33,15 @@ template constexpr bool isValidSM120MOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice +#if defined(ENABLE_FP4) return ((cutlass::platform::is_same::value && cutlass::platform::is_same::value) || (cutlass::platform::is_same::value && cutlass::platform::is_same::value)) && cutlass::platform::is_same::value; +#else + return false; +#endif #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -49,10 +53,15 @@ template constexpr bool isValidBlackwellMOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice +#if defined(ENABLE_FP4) return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && cutlass::platform::is_same::value)) && cutlass::platform::is_same::value; +#else + return cutlass::platform::is_same::value && + cutlass::platform::is_same::value; +#endif #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled #endif @@ -67,9 +76,12 @@ constexpr bool isValidHopperMOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && - cutlass::platform::is_same::value) || - (cutlass::platform::is_same<__nv_fp4_e2m1, WeightType>::value && - !cutlass::platform::is_same::value)) + cutlass::platform::is_same::value) +#ifdef ENABLE_FP4 + || (cutlass::platform::is_same<__nv_fp4_e2m1, WeightType>::value && + !cutlass::platform::is_same::value) +#endif + ) #ifdef ENABLE_FP4 && !cutlass::platform::is_same::value #endif From 13d86647f9c7a0f920360fe2da4d4b76c8f6d160 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 17 Oct 2025 10:02:33 -0700 Subject: [PATCH 13/17] fix: use FLASHINFER_ENABLE_FP8_E8M0 guard for __nv_fp8_e8m0 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 67c1c62a74..61c310f5ed 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -4047,7 +4047,7 @@ CutlassMoeFCRunner:: "WFP4AFP8 expects the scaling factors to be aliased for gemm1 & gemm2"); TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int{}; -#ifdef ENABLE_FP8 +#ifdef FLASHINFER_ENABLE_FP8_E8M0 __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(1.0f, __NV_SATFINITE, cudaRoundPosInf); std::memcpy(&weight_block_scale_value_int, &tmp, sizeof(tmp)); From 4f94bf0c5961c3ad385a2964f1762ed9062c1f80 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 17 Oct 2025 12:43:21 -0700 Subject: [PATCH 14/17] fix build --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 61c310f5ed..698dc1f1d5 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -4047,7 +4047,7 @@ CutlassMoeFCRunner:: "WFP4AFP8 expects the scaling factors to be aliased for gemm1 & gemm2"); TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int{}; -#ifdef FLASHINFER_ENABLE_FP8_E8M0 +#if defined(FLASHINFER_ENABLE_FP8_E8M0) && CUDART_VERSION >= 12080 __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(1.0f, __NV_SATFINITE, cudaRoundPosInf); std::memcpy(&weight_block_scale_value_int, &tmp, sizeof(tmp)); From eddb10b2ef6553950c062eb7291cc488dca3e789 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 17 Oct 2025 13:31:16 -0700 Subject: [PATCH 15/17] fix aot errors --- .../fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 698dc1f1d5..b224a09af6 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1358,13 +1358,14 @@ __global__ void expandInputRowsKernel( static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, "AWQ and Block Scaling are mutually exclusive"); -#ifdef ENABLE_FP4 constexpr bool is_mxfp8 = std::is_same_v && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX && !PRE_QUANT_AWQ; constexpr bool is_mxfp8_input = is_mxfp8 && std::is_same_v; constexpr bool need_mxfp8_quant = is_mxfp8 && !is_mxfp8_input; + +#ifdef ENABLE_FP4 constexpr bool is_nvfp4 = std::is_same_v && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 && @@ -1372,9 +1373,6 @@ __global__ void expandInputRowsKernel( constexpr bool is_nvfp4_input = is_nvfp4 && std::is_same_v; constexpr bool need_nvfp4_quant = is_nvfp4 && !is_nvfp4_input; #else - constexpr bool is_mxfp8 = false; - constexpr bool is_mxfp8_input = false; - constexpr bool need_mxfp8_quant = false; constexpr bool is_nvfp4 = false; constexpr bool is_nvfp4_input = false; constexpr bool need_nvfp4_quant = false; From 256355625cab0b238eb0140abe3cde6c7053f399 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Tue, 21 Oct 2025 23:42:27 +0000 Subject: [PATCH 16/17] fix stale sm100 configs --- .../cutlass_kernels/cutlass_heuristic.cpp | 119 +++++++++++++++++- 1 file changed, 116 insertions(+), 3 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 80b1928437..63bf1235f6 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -264,6 +264,119 @@ bool sm90_supports_mcast_along_n(CutlassTileConfigSM90 const tile) { #endif } +std::vector get_candidate_configs_sm100_dynamic_cluster_shape( + int sm, CutlassGemmConfig::CandidateConfigTypeParam const config, EpilogueScheduleType schedule, + ClusterShape const dynamic_cluster_shape, ClusterShape const fallback_cluster_shape) { + auto cluster1sm = ClusterShape::ClusterShape_1x1x1; + auto cluster2sm = ClusterShape::ClusterShape_2x1x1; + 
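A worked example of the 2-SM eligibility test that follows, reusing enum_to_shape_tuple from gemm_configs.h (sketch only): a dynamic cluster of 4x2x1 decodes to M=4, which is even, so the candidate list keeps both the 1-SM and 2-SM launch variants; ClusterShape::Undefined also keeps the 2-SM variants, since it means static cluster shapes are in use.

  // The leading (M) dimension of the dynamic cluster shape decides 2-SM eligibility:
  static_assert(std::get<0>(enum_to_shape_tuple(ClusterShape::ClusterShape_4x2x1)) % 2 == 0,
                "4x2x1 has an even M, so tiles paired with 2x1x1 (2-SM mode) stay in the list");
  // A shape with odd M, e.g. ClusterShape_1x2x1, would restrict the candidates to 1-SM configs.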
bool supports_2sm = dynamic_cluster_shape == ClusterShape::Undefined || + std::get<0>(enum_to_shape_tuple(dynamic_cluster_shape)) % 2 == 0; + + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + if (sm == 100) { + if (schedule != EpilogueScheduleType::TMA) return {}; + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, schedule, + cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + if (supports_2sm) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + } + } + + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + candidate_configs.push_back( + CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, + schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + if (supports_2sm) { + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + candidate_configs.push_back(CutlassGemmConfig{ + CutlassTileConfigSM100::CtaShape128x256x128B, MainloopScheduleType::AUTO, schedule, + cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm}); + } + return candidate_configs; + } + + std::vector> tile_configs{ + {CutlassTileConfigSM100::CtaShape128x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x64x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x32x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x128x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape64x256x128B, cluster1sm}, + {CutlassTileConfigSM100::CtaShape128x64x128B, cluster1sm}, + }; + + if (supports_2sm) { + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x128x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x256x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape64x64x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x64x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x128x128B, cluster2sm}); + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x256x128B, cluster2sm}); + } + + if (config & CutlassGemmConfig::FP8_ONLY) { + tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, cluster1sm}); + // TODO: re-enable when handled by the MoE GEMM dispatch + // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, + // ClusterShape::ClusterShape_1x1x1 }); + } + + for (auto [tile, cluster] : tile_configs) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, schedule, + cluster, dynamic_cluster_shape, fallback_cluster_shape, + sm}; + candidate_configs.push_back(config); + } + return candidate_configs; +} + +std::vector get_candidate_configs_sm100( + CutlassGemmConfig::CandidateConfigTypeParam const config, int sm) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, 
EpilogueScheduleType::TMA, + ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, + ClusterShape::Undefined, sm}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + for (auto schedule : {EpilogueScheduleType::TMA, EpilogueScheduleType::NO_SMEM}) { + // TODO The tactic profiling is a bit long with all of these shapes enabled + // Shape 4x4x1 shapes do not seem to give better performance in the cases I tested so we + // disable it here + auto cluster_shapes = { + ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_4x1x1, + ClusterShape::ClusterShape_4x2x1 /*, ClusterShape::ClusterShape_4x4x1*/}; + for (auto cluster_shape : cluster_shapes) { + auto fallback_cluster_shape = cluster_shape == ClusterShape::ClusterShape_1x1x1 + ? ClusterShape::ClusterShape_1x1x1 + : ClusterShape::ClusterShape_2x1x1; + auto configs = get_candidate_configs_sm100_dynamic_cluster_shape( + sm, config, schedule, cluster_shape, fallback_cluster_shape); + candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end()); + } + + auto configs = get_candidate_configs_sm100_dynamic_cluster_shape( + sm, config, schedule, ClusterShape::Undefined, ClusterShape::Undefined); + candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end()); + } + return candidate_configs; + } else { + TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif +} + std::vector get_candidate_configs_sm90( CutlassGemmConfig::CandidateConfigTypeParam const config) { auto tiles = get_candidate_tiles_sm90(config); @@ -330,7 +443,7 @@ std::vector get_candidate_configs_sm90( return candidate_configs; } -std::vector get_candidate_configs_sm100( +/*std::vector get_candidate_configs_sm100( CutlassGemmConfig::CandidateConfigTypeParam const config) { #ifdef FAST_BUILD // Fast build disables all configs except this one for SM100 @@ -413,7 +526,7 @@ std::vector get_candidate_configs_sm100( TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); } #endif -} +}*/ std::vector get_candidate_configs_sm110( CutlassGemmConfig::CandidateConfigTypeParam const config) { @@ -538,7 +651,7 @@ std::vector get_candidate_configs( return get_candidate_configs_sm110(config_type_param); } if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { - return get_candidate_configs_sm100(config_type_param); + return get_candidate_configs_sm100(config_type_param, sm); } if (sm >= 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { return get_candidate_configs_sm120(config_type_param); From da54367259316af6ff53b63d7d89d6037e89939a Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 22 Oct 2025 01:05:41 +0000 Subject: [PATCH 17/17] debug.. 
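A usage sketch for the SM100 candidate generation added by the preceding patch, assuming the declarations above (namespace qualifiers omitted); the flag mask and the printing are illustrative and not part of the patch.

  #include <iostream>

  void list_sm100_candidates(int sm) {
    auto config = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
        CutlassGemmConfig::GROUPED_GEMM | CutlassGemmConfig::FP8_ONLY);  // example flag mask
    for (CutlassGemmConfig const& cfg : get_candidate_configs_sm100(config, sm)) {
      std::cout << cfg.getTileConfigAsName()                                       // e.g. "128x128x128"
                << " cluster=" << get_cluster_shape_name(cfg.cluster_shape)        // 1x1x1 or 2x1x1
                << " dynamic=" << get_cluster_shape_name(cfg.dynamic_cluster_shape)
                << " fallback=" << get_cluster_shape_name(cfg.fallback_cluster_shape)
                << (cfg.epilogue_schedule == EpilogueScheduleType::TMA ? " epi=TMA" : " epi=NO_SMEM")
                << "\n";
    }
  }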
--- .../moe_gemm/moe_gemm_template_dispatch_tma_ws.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 108ba9b81c..a547c30b4c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -189,6 +189,14 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( auto cluster_shape_cute_fallback = cute::Shape{ std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; + // HACK debug the gemm_config used to produce selected_func + std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version + << ", tile_config_sm100=" << static_cast(gemm_config.tile_config_sm100) + << ", epilogue_schedule=" << static_cast(gemm_config.epilogue_schedule) + << ", dynamic_cluster_shape=" << static_cast(gemm_config.dynamic_cluster_shape) + << ", fallback_cluster_shape=" + << static_cast(gemm_config.fallback_cluster_shape) << std::endl; + auto selected_func = getDispatchFunctionForSM100(