diff --git a/include/flashinfer/gemm/fp4_gemm_template_sm120.h b/include/flashinfer/gemm/fp4_gemm_template_sm120.h index 37b3b9d9ad..8abefc7001 100644 --- a/include/flashinfer/gemm/fp4_gemm_template_sm120.h +++ b/include/flashinfer/gemm/fp4_gemm_template_sm120.h @@ -185,8 +185,8 @@ inline size_t runFp4GemmImpl(void* D, void const* A, void const* B, void const* std::string(cutlass::cutlassGetStatusString(initStatus))); } - // Disable PDL since GDC flag is not set - auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/false); + // Enable PDL — GDC flag (CUTLASS_ENABLE_GDC_FOR_SM100) is set at compile time + auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); if (runStatus != cutlass::Status::kSuccess) { throw std::runtime_error(std::string("[FP4 gemm Runner") + scheduler_name + "] " + "Failed to run cutlass FP4 gemm on sm120/sm121. Error: " + @@ -240,39 +240,39 @@ inline size_t runFp4GemmImpl(void* D, void const* A, void const* B, void const* using ElementCompute = float; \ using ElementAccumulator = float; \ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; \ - using EpilogueTileType = \ - cutlass::epilogue::collective::EpilogueTileAuto; /* Always use auto for SM100 */ \ - using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; \ - using MainloopSchedule = cutlass::gemm::collective::KernelScheduleAuto; \ + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; \ + using FusionOperation = \ + cutlass::epilogue::fusion::LinearCombination; \ using ThreadBlockShape = cute::Shape, cute::Int, cute::Int>; \ + /* Epilogue: explicit TmaWarpSpecialized schedule (matches TRT-LLM SM120 pattern) */ \ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< \ - Arch, OperatorClass, ThreadBlockShape, ClusterShape, EpilogueTileType, ElementAccumulator, \ - ElementCompute, ElementC, LayoutC, AlignmentC, OutElementType, LayoutC, AlignmentC, \ - EpilogueSchedule, \ - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; \ + Arch, cutlass::arch::OpClassTensorOp, ThreadBlockShape, ClusterShape, \ + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, \ + ElementC, LayoutC, AlignmentC, OutElementType, LayoutC, AlignmentC, \ + cutlass::epilogue::TmaWarpSpecialized, FusionOperation>::CollectiveOp; \ \ /* SM120/SM121 BlockScaled - Use nv_float4_t without tuples like example 79 */ \ - /* Use fixed 2 stages for SM120/SM121 to meet minimum requirement */ \ + /* Dynamic stage carveout adapts pipeline depth to available smem after epilogue */ \ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< \ Arch, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, \ ElementAccumulator, ThreadBlockShape, ClusterShape, \ - cutlass::gemm::collective::StageCount<2>, /* Fixed 2 stages for SM120 */ \ - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; \ + cutlass::gemm::collective::StageCountAutoCarveout( \ + sizeof(typename CollectiveEpilogue::SharedStorage))>, \ + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; \ \ /* Two scheduler options for different workloads */ \ /* See: https://github.com/NVIDIA/cutlass/blob/main/examples/79_blackwell_geforce_gemm */ \ + using TileSchedulerTag = cutlass::gemm::StaticPersistentScheduler; \ \ - /* Option 1: Default scheduler (void) - Data Parallel, good for regular shapes */ \ - using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal< \ - cute::Shape, /* Indicates ProblemShape */ \ - CollectiveMainloop, CollectiveEpilogue, void>; /* Default DP scheduler */ \ + /* Option 1: Persistent scheduler - reduced launch overhead, good default */ \ + using GemmKernelDefault = \ + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, \ + CollectiveEpilogue, TileSchedulerTag>; \ \ /* Option 2: StreamK scheduler - better load balancing for small M/N, large K */ \ - using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal< \ - cute::Shape, /* Indicates ProblemShape */ \ - CollectiveMainloop, CollectiveEpilogue, \ - cutlass::gemm::StreamKScheduler>; /* StreamK scheduler */ \ + using GemmKernelStreamK = \ + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, \ + CollectiveEpilogue, cutlass::gemm::StreamKScheduler>; \ \ using GemmDefault = typename cutlass::gemm::device::GemmUniversalAdapter; \ using GemmStreamK = typename cutlass::gemm::device::GemmUniversalAdapter; \