Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ def gen_gemm_sm120_module_cutlass_fp4() -> JitSpec:
with open(jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass_sm120.jinja") as f:
kernel_inst_templ = jinja2.Template(f.read())
dtype_list = ["__nv_bfloat16", "half"]
# SM120/121 uses only 128x128x128 tile configuration with implied 1x1x1 cluster shape
# SM120/121 tile configurations with implied 1x1x1 cluster shape
cta_m_n_k_list = [
(128, 128, 128),
(128, 128, 256),
(256, 128, 128),
]
for cta_m, cta_n, cta_k in cta_m_n_k_list:
for dtype in dtype_list:
Expand Down
12 changes: 10 additions & 2 deletions include/flashinfer/gemm/cutlass_gemm_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ struct CutlassGemmConfig {
bool enableCudaKernel = false;
int sm_version = 80; // Use 80 as a catch all for <90
bool is_tma_warp_specialized = false;
bool use_stream_k = false; // SM120: false = DP scheduler (default), true = StreamK scheduler

CutlassGemmConfig() = default;

Expand Down Expand Up @@ -352,15 +353,18 @@ struct CutlassGemmConfig {
sm_version(100),
is_tma_warp_specialized(true) {}

// SM120 constructor with optional StreamK scheduler
// use_stream_k: false = DP scheduler (default), true = StreamK scheduler (auto heuristic)
CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120,
MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule,
ClusterShape cluster_shape)
ClusterShape cluster_shape, bool use_stream_k = false)
: tile_config_sm120(tile_config_sm120),
mainloop_schedule(mainloop_schedule),
epilogue_schedule(epilogue_schedule),
cluster_shape(cluster_shape),
sm_version(120),
is_tma_warp_specialized(true) {}
is_tma_warp_specialized(true),
use_stream_k(use_stream_k) {}

int getTileConfigAsInt() const {
if (sm_version == 120) return (int)tile_config_sm120;
Expand All @@ -383,6 +387,10 @@ struct CutlassGemmConfig {
<< "\n\tmainloop sched: " << (int)mainloop_schedule
<< "\n\tepi sched: " << (int)epilogue_schedule
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
// SM120 specific: StreamK scheduler option
if (sm_version == 120) {
tactic << "\n\tscheduler: " << (use_stream_k ? "StreamK (auto heuristic)" : "DP (default)");
}
} else if (tile_config_sm80 != flashinfer::gemm::CutlassTileConfig::ChooseWithHeuristic) {
assert(sm_version < 90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=compatible"
Expand Down
83 changes: 58 additions & 25 deletions include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ namespace flashinfer {
namespace gemm {
using namespace cute;

template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_>
// UseStreamK: false = DP scheduler (default), true = StreamK scheduler
template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_, bool UseStreamK = false>
size_t dispatchNVFP4xNVFP4GemmClusterShapeSm120(T* D, void const* A, void const* B,
void const* input_sf, void const* weight_sf,
float const* global_sf, int m, int n, int k,
Expand All @@ -53,10 +54,17 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm120(T* D, void const* A, void const*
cudaStream_t stream, int* occupancy = nullptr) {
// For SM120/SM121, only support 1x1x1 cluster shape
// Always use 1x1x1 cluster shape regardless of gemmConfig.cluster_shape
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<1>,
cute::Int<1>, _1SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
workspaceBytes, stream, occupancy);
if constexpr (UseStreamK) {
return genericFp4GemmKernelLauncherStreamK<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>,
cute::Int<1>, cute::Int<1>, _1SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
workspaceBytes, stream, occupancy);
} else {
return genericFp4GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<1>, cute::Int<1>,
cute::Int<1>, _1SM>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
workspaceBytes, stream, occupancy);
}
}

/*!
Expand All @@ -78,40 +86,50 @@ size_t dispatchNVFP4xNVFP4GemmClusterShapeSm120(T* D, void const* A, void const*
* \param occupancy Optional pointer to store kernel occupancy
* \return Size of workspace required in bytes
*/
// Helper macro to dispatch tile config with scheduler selection
#define DISPATCH_TILE_CONFIG(CTA_M, CTA_N, CTA_K, USE_STREAMK) \
return dispatchNVFP4xNVFP4GemmClusterShapeSm120<T, cute::Int<CTA_M>, cute::Int<CTA_N>, \
cute::Int<CTA_K>, USE_STREAMK>( \
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, \
workspaceBytes, stream, occupancy)

// Helper macro to dispatch with scheduler check
#define DISPATCH_WITH_SCHEDULER(CTA_M, CTA_N, CTA_K) \
if (gemmConfig.use_stream_k) { \
DISPATCH_TILE_CONFIG(CTA_M, CTA_N, CTA_K, true); \
} else { \
DISPATCH_TILE_CONFIG(CTA_M, CTA_N, CTA_K, false); \
}

template <typename T>
size_t dispatchNVFP4xNVFP4GemmCTAShapeSm120(T* D, void const* A, void const* B,
void const* input_sf, void const* weight_sf,
float const* global_sf, int m, int n, int k,
int batch_count, CutlassGemmConfig gemmConfig,
char* workspace, const size_t workspaceBytes,
cudaStream_t stream, int* occupancy = nullptr) {
// For SM120/SM121, we only support 128x128x128 tile configuration
// Check the SM120 tile config and dispatch accordingly
// Dispatch based on tile config and scheduler type
switch (gemmConfig.tile_config_sm120) {
case CutlassTileConfigSM120::CtaShape128x128x128B:
// Always use 1x1x1 cluster shape for SM120
return dispatchNVFP4xNVFP4GemmClusterShapeSm120<T, cute::Int<128>, cute::Int<128>,
cute::Int<128>>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
workspaceBytes, stream, occupancy);
break;
DISPATCH_WITH_SCHEDULER(128, 128, 128);
case CutlassTileConfigSM120::CtaShape128x128x256B:
DISPATCH_WITH_SCHEDULER(128, 128, 256);
case CutlassTileConfigSM120::CtaShape256x128x128B:
DISPATCH_WITH_SCHEDULER(256, 128, 128);
Comment on lines 113 to +118
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The break statements are missing between these case labels. While the DISPATCH_WITH_SCHEDULER macro expands to a return statement, which prevents functional issues from fallthrough, this structure is confusing and can be flagged by compilers with -Wimplicit-fallthrough. Adding break statements makes the control flow explicit and improves readability and maintainability, even if the break is currently unreachable.

Suggested change
case CutlassTileConfigSM120::CtaShape128x128x128B:
// Always use 1x1x1 cluster shape for SM120
return dispatchNVFP4xNVFP4GemmClusterShapeSm120<T, cute::Int<128>, cute::Int<128>,
cute::Int<128>>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
workspaceBytes, stream, occupancy);
break;
DISPATCH_WITH_SCHEDULER(128, 128, 128);
case CutlassTileConfigSM120::CtaShape128x128x256B:
DISPATCH_WITH_SCHEDULER(128, 128, 256);
case CutlassTileConfigSM120::CtaShape256x128x128B:
DISPATCH_WITH_SCHEDULER(256, 128, 128);
case CutlassTileConfigSM120::CtaShape128x128x128B:
DISPATCH_WITH_SCHEDULER(128, 128, 128);
break;
case CutlassTileConfigSM120::CtaShape128x128x256B:
DISPATCH_WITH_SCHEDULER(128, 128, 256);
break;
case CutlassTileConfigSM120::CtaShape256x128x128B:
DISPATCH_WITH_SCHEDULER(256, 128, 128);
break;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each DISPATCH macro will return, so I suppose it's fine.

case CutlassTileConfigSM120::Undefined:
throw std::runtime_error("[Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined.");
break;
case CutlassTileConfigSM120::ChooseWithHeuristic:
throw std::runtime_error(
"[Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been set by "
"heuristic.");
break;
default:
// For any other SM120 tile configs that we don't support yet, fall back to 128x128x128
return dispatchNVFP4xNVFP4GemmClusterShapeSm120<T, cute::Int<128>, cute::Int<128>,
cute::Int<128>>(
D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace,
workspaceBytes, stream, occupancy);
break;
DISPATCH_WITH_SCHEDULER(128, 128, 128); // Fallback
}
}

#undef DISPATCH_WITH_SCHEDULER
#undef DISPATCH_TILE_CONFIG

template <typename T, FP4GemmType fp4GemmType>
CutlassFp4GemmRunner<T, fp4GemmType>::CutlassFp4GemmRunner() {}

Expand Down Expand Up @@ -150,11 +168,26 @@ template <typename T, FP4GemmType fp4GemmType>
std::vector<CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getConfigs() const {
std::vector<CutlassGemmConfig> candidateConfigs;

// For SM120/SM121, only support 128x128x128 tile with 1x1x1 cluster shape
CutlassGemmConfig config(CutlassTileConfigSM120::CtaShape128x128x128B, MainloopScheduleType::AUTO,
EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidateConfigs.push_back(config);
// All supported tile configurations for SM120
std::vector<CutlassTileConfigSM120> tilesSm120 = {
CutlassTileConfigSM120::CtaShape128x128x128B,
CutlassTileConfigSM120::CtaShape128x128x256B,
CutlassTileConfigSM120::CtaShape256x128x128B,
};

// SM120 only supports 1x1x1 cluster shape
ClusterShape clusterShape = ClusterShape::ClusterShape_1x1x1;

// Generate configs for both DP and StreamK schedulers
for (auto const& tile_config : tilesSm120) {
// Default DP scheduler (use_stream_k = false)
candidateConfigs.push_back(CutlassGemmConfig(tile_config, MainloopScheduleType::AUTO,
EpilogueScheduleType::AUTO, clusterShape, false));

// StreamK scheduler (use_stream_k = true) - better for small M/N, large K
candidateConfigs.push_back(CutlassGemmConfig(tile_config, MainloopScheduleType::AUTO,
EpilogueScheduleType::AUTO, clusterShape, true));
}
return candidateConfigs;
}

Expand Down
Loading
Loading