Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions include/flashinfer/gemm/fp4_gemm_template_sm120.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ inline size_t runFp4GemmImpl(void* D, void const* A, void const* B, void const*
std::string(cutlass::cutlassGetStatusString(initStatus)));
}

// Disable PDL since GDC flag is not set
auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/false);
// Enable PDL — GDC flag (CUTLASS_ENABLE_GDC_FOR_SM100) is set at compile time
auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true);
if (runStatus != cutlass::Status::kSuccess) {
throw std::runtime_error(std::string("[FP4 gemm Runner") + scheduler_name + "] " +
"Failed to run cutlass FP4 gemm on sm120/sm121. Error: " +
Expand Down Expand Up @@ -240,39 +240,39 @@ inline size_t runFp4GemmImpl(void* D, void const* A, void const* B, void const*
using ElementCompute = float; \
using ElementAccumulator = float; \
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; \
using EpilogueTileType = \
cutlass::epilogue::collective::EpilogueTileAuto; /* Always use auto for SM100 */ \
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; \
using MainloopSchedule = cutlass::gemm::collective::KernelScheduleAuto; \
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; \
using FusionOperation = \
cutlass::epilogue::fusion::LinearCombination<OutElementType, float, void, float>; \
using ThreadBlockShape = cute::Shape<cute::Int<CTA_M_>, cute::Int<CTA_N_>, cute::Int<CTA_K_>>; \
/* Epilogue: explicit TmaWarpSpecialized schedule (matches TRT-LLM SM120 pattern) */ \
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< \
Arch, OperatorClass, ThreadBlockShape, ClusterShape, EpilogueTileType, ElementAccumulator, \
ElementCompute, ElementC, LayoutC, AlignmentC, OutElementType, LayoutC, AlignmentC, \
EpilogueSchedule, \
cutlass::epilogue::fusion::LinearCombination<OutElementType, float, void, \
float>>::CollectiveOp; \
Arch, cutlass::arch::OpClassTensorOp, ThreadBlockShape, ClusterShape, \
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, \
ElementC, LayoutC, AlignmentC, OutElementType, LayoutC, AlignmentC, \
cutlass::epilogue::TmaWarpSpecialized, FusionOperation>::CollectiveOp; \
\
/* SM120/SM121 BlockScaled - Use nv_float4_t without tuples like example 79 */ \
/* Use fixed 2 stages for SM120/SM121 to meet minimum requirement */ \
/* Dynamic stage carveout adapts pipeline depth to available smem after epilogue */ \
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< \
Arch, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, \
ElementAccumulator, ThreadBlockShape, ClusterShape, \
cutlass::gemm::collective::StageCount<2>, /* Fixed 2 stages for SM120 */ \
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; \
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( \
sizeof(typename CollectiveEpilogue::SharedStorage))>, \
cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; \
Comment thread
bkryu marked this conversation as resolved.
\
/* Two scheduler options for different workloads */ \
/* See: https://github.com/NVIDIA/cutlass/blob/main/examples/79_blackwell_geforce_gemm */ \
using TileSchedulerTag = cutlass::gemm::StaticPersistentScheduler; \
\
/* Option 1: Default scheduler (void) - Data Parallel, good for regular shapes */ \
using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal< \
cute::Shape<int, int, int, int>, /* Indicates ProblemShape */ \
CollectiveMainloop, CollectiveEpilogue, void>; /* Default DP scheduler */ \
/* Option 1: Persistent scheduler - reduced launch overhead, good default */ \
using GemmKernelDefault = \
cutlass::gemm::kernel::GemmUniversal<cute::Shape<int, int, int, int>, CollectiveMainloop, \
CollectiveEpilogue, TileSchedulerTag>; \
Comment thread
coderabbitai[bot] marked this conversation as resolved.
\
/* Option 2: StreamK scheduler - better load balancing for small M/N, large K */ \
using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal< \
cute::Shape<int, int, int, int>, /* Indicates ProblemShape */ \
CollectiveMainloop, CollectiveEpilogue, \
cutlass::gemm::StreamKScheduler>; /* StreamK scheduler */ \
using GemmKernelStreamK = \
cutlass::gemm::kernel::GemmUniversal<cute::Shape<int, int, int, int>, CollectiveMainloop, \
CollectiveEpilogue, cutlass::gemm::StreamKScheduler>; \
\
using GemmDefault = typename cutlass::gemm::device::GemmUniversalAdapter<GemmKernelDefault>; \
using GemmStreamK = typename cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>; \
Expand Down
Loading