6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.cpp
@@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE()
return forceDeterministic;
}

bool getEnvMOEDisableFinalizeFusion()
{
static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION");
return moeDisableFinalizeFusion;
}

bool getEnvForceDeterministicAttention()
{
static bool const forceDeterministic
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.h
@@ -86,6 +86,9 @@ bool getEnvForceDeterministic();
// Force deterministic behavior for MoE plugin.
bool getEnvForceDeterministicMOE();

// Disable finalize fusion in MoE plugin.
bool getEnvMOEDisableFinalizeFusion();

// Force deterministic behavior for attention plugin.
bool getEnvForceDeterministicAttention();

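The new helper follows the existing getEnv* pattern in envUtils: TRTLLM_MOE_DISABLE_FINALIZE_FUSION is read once into a function-local static and the cached value is returned on every call. As a rough illustration (not taken from this PR), a runner exposing the use_fused_finalize_ flag introduced later in this diff could consume it as below, assuming the helper lives in tensorrt_llm::common alongside the other accessors and the runner type comes from moe_kernels.h:

// Sketch only: turn off the fused finalize epilogue when the env switch is set.
#include "tensorrt_llm/common/envUtils.h"

void applyMoeEnvOverrides(CutlassMoeFCRunnerInterface& runner)
{
    // TRTLLM_MOE_DISABLE_FINALIZE_FUSION=1 forces the unfused finalize path.
    runner.use_fused_finalize_ = !tensorrt_llm::common::getEnvMOEDisableFinalizeFusion();
}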

This file was deleted.

Large diffs are not rendered by default.

@@ -133,11 +133,6 @@ enum class CutlassTileConfigSM100
CtaShape128x256x128B,
CtaShape128x128x256B,
CtaShape128x256x256B,

// M=256
CtaShape256x64x128B,
CtaShape256x128x128B,
CtaShape256x256x128B,
};

enum class CutlassTileConfigSM120
@@ -19,7 +19,7 @@
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"

namespace tensorrt_llm::cutlass_extensions
namespace cutlass::util
{

/// Function object that applies an index to its argument
@@ -81,7 +81,7 @@ struct CustomStride
template <class Div>
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
{
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
return CustomStride<Func, decltype(cute::safe_div(s.stride_, div))>(s.func_, cute::safe_div(s.stride_, div));
}

// Circumvent the requirement on make_layout that shape and stride are integral
@@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}
} // namespace tensorrt_llm::cutlass_extensions
} // namespace cutlass::util

namespace cute
{
92 changes: 41 additions & 51 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
@@ -377,72 +377,62 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
std::vector<CutlassGemmConfig> candidate_configs;
if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
if (config & CutlassGemmConfig::FP4_ONLY)
{
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
return candidate_configs;
}

for (int cluster_m = 1; cluster_m <= 2; cluster_m++)
std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1},
{CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1},
{CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1},
{CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1},
};

if (config & CutlassGemmConfig::FP8_ONLY)
{
bool Is2SM = cluster_m == 2;
for (int cluster_n = 1; cluster_n <= 2; cluster_n++)
{
std::vector base = {// M=128
CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B};

if (Is2SM)
{
if (cluster_n == 1)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
}

std::vector twosm = {// M=256
CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B};
std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
}
else
{
if (cluster_n == 1)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
if ((config & CutlassGemmConfig::FP8_ONLY) != 0)
{
base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
}
}

std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x64x128B};
std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
}
tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1});
// TODO: re-enable when handled by the MoE GEMM dispatch
// tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 });
}

constexpr std::array cluster_shapes
= {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
for (auto tile : base)
{
CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
candidate_configs.push_back(config);
}
}
for (auto [tile, cluster] : tile_configs)
{
CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
candidate_configs.push_back(config);
}
return candidate_configs;
}
44 changes: 13 additions & 31 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h
@@ -37,11 +37,6 @@

namespace tensorrt_llm::kernels::cutlass_kernels
{
template <class T>
constexpr auto transpose_stride(T const& t)
{
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}

template <typename AType, typename BType, typename BScaleType, typename OType>
struct GroupedGemmInput
@@ -72,8 +67,6 @@ struct GroupedGemmInput

struct TmaWarpSpecializedGroupedGemmInput
{
template <class T>
using TransposeStride = decltype(transpose_stride<T>(T{}));
template <class Tag>
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
@@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand

constexpr static int NVFP4BlockScaleVectorSize = 16;
constexpr static int MXFPXBlockScaleVectorSize = 32;
@@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput
using StrideB
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

#ifdef ENABLE_FP8
template <class T>
@@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput
StrideC* stride_c = nullptr;
void const** ptr_c = nullptr;

struct DefaultEpilogue
{
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

StrideD* stride_d = nullptr;
void** ptr_d = nullptr;
};
// D is used in all cases except fused finalize
StrideD* stride_d = nullptr;
void** ptr_d = nullptr;

struct FusedFinalizeEpilogue
{
using StrideFinalOutput = DefaultEpilogue::StrideD;
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;

void* ptr_final_output = nullptr;
StrideFinalOutput stride_final_output{};

void const* ptr_bias = nullptr;
StrideBias stride_bias{};

float const* ptr_router_scales = nullptr;
StrideRouterScales stride_router_scales{};
void const** ptr_bias = nullptr;
float const** ptr_router_scales = nullptr;

int64_t const* ptr_expert_first_token_offset = nullptr;
int const* ptr_source_token_index = nullptr;
int const** ptr_source_token_index = nullptr;
int num_rows_in_final_output = 0;

size_t num_rows_in_final_output = 0;
bool use_reduction = true;
};

DefaultEpilogue default_epilogue;
FusedFinalizeEpilogue fused_finalize_epilogue;

enum class EpilogueFusion
@@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;

static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);

static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);

@@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput
return stride_a != nullptr && ptr_a != nullptr;
}

void setFinalizeFusionParams(void* final_output, float const* router_scales,
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
int num_output_tokens);
void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);

std::string toString() const;
};
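With DefaultEpilogue folded away and the finalize-fusion epilogue now carrying per-expert pointer arrays (bias, router scales, source token indices), setFinalizeFusionParams shrinks to the four arguments above. A minimal usage sketch under those assumptions; the buffer names are illustrative, not taken from this PR:

// Sketch: let gemm2 write directly into the final MoE output tensor.
TmaWarpSpecializedGroupedGemmInput gemm2_inputs;
// ... strides and pointer arrays for the grouped GEMM are populated as before ...
bool const use_reduction = true; // presumably: sum the top-k expert contributions per token
gemm2_inputs.setFinalizeFusionParams(final_output, hidden_size, num_output_tokens, use_reduction);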
20 changes: 12 additions & 8 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
@@ -495,7 +495,8 @@ class CutlassMoeFCRunnerInterface
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream)
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream)
= 0;

virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -512,13 +513,13 @@ class CutlassMoeFCRunnerInterface
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0;

bool is_profiler = false;
bool use_deterministic_hopper_reduce_ = false;
bool use_fused_finalize_ = true;
};

// Assumes input activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc.
// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive.
// Avoid making several duplicates of this class.
template <typename T, /*The type used for activations*/
template <typename T, /* The type used for activations */
typename WeightType, /* The type for the MoE weights */
typename OutputType = T, /* The type for the MoE final output */
typename InputType = T, /* The type for the MoE input */
@@ -709,7 +710,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override
void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override
{
return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens,
expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node,
@@ -718,7 +720,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params,
reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream);
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
stream);
}

std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -760,7 +763,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
UnfusedGemmOutputType* gemm2_output, cudaStream_t stream);
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
cudaStream_t stream);
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@@ -790,8 +794,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface

bool mayHaveFinalizeFused() const
{
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90
&& !use_deterministic_hopper_reduce_ && !use_w4_groupwise;
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
&& !use_w4_groupwise;
}

// TODO: This should eventually take the quant params to give more flexibility
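The Hopper-only deterministic flag is gone; mayHaveFinalizeFused now admits any SM90-or-newer TMA warp-specialized path, gated by use_fused_finalize_ and the w4 groupwise exclusion. A hedged sketch of how a caller might use it to choose gemm2's output buffer, assuming the method stays publicly reachable on the runner and with illustrative buffer names:

// Sketch: route gemm2's output depending on whether the finalize step can be fused.
bool const finalize_fused = runner.mayHaveFinalizeFused();
void* gemm2_output = finalize_fused ? final_hidden_states  // fused: the epilogue scatters/reduces into the final output
                                    : intermediate_buffer; // unfused: a separate finalize kernel runs afterwards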