diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
index 1738c121f1..1be2e43145 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
@@ -41,10 +41,13 @@
 #include "collective/fmha_common.hpp"
 #include "collective/fmha_fusion.hpp"
 #include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp"
+#include "cutlass/detail/dependent_false.hpp"
 
 namespace cutlass::fmha::collective {
 
 using namespace cute;
+using namespace constexpr_type_map;
+using namespace constexpr_value_map;
 
 template<
   class Element_,
@@ -85,10 +88,32 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   using StrideO = decltype(replace<0>(StrideO_{}, 0));
   using Mask = Mask_;
 
+  using TileM = decltype(get<0>(TileShape{}));  // seq Q dim
+  static_assert(TileM::value == 64 || TileM::value == 128, "Only expecting TileM to be 64 or 128");
   static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
-  // local changes
-  static constexpr int StageCountKV = StageCountQ * (sizeof(Element) == 1 ? 11 : 5) ;
-
+  // Choose StageCountKV based on:
+  // - the tile shape on the M (i.e., query) dimension
+  // - the element size
+  using StageCountKVSelector = kValTyMap<
+      void,
+      kValTyPair<64,
+          kValValMap<
+              65536 /* default, arbitrarily large to trigger smem OOM error */,
+              kValValPair<1, 12>,  // fp8
+              kValValPair<2, 6>    // bf16/fp16
+          >>,
+      kValTyPair<128,
+          kValValMap<
+              65536 /* default, arbitrarily large to trigger smem OOM error */,
+              kValValPair<1, 11>,  // fp8
+              kValValPair<2, 5>    // bf16/fp16
+          >>
+  >;
+  static constexpr int StageCountKV = StageCountQ *
+      StageCountKVSelector::
+          template query<TileM::value>::
+          template query<sizeof(Element)>;
+
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
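For reference, kValTyMap / kValValMap come from the constexpr_type_map / constexpr_value_map helper headers added elsewhere in this PR and are not shown in this diff. A minimal sketch of the interface the call sites above assume (query<Key> selects the entry paired with Key and falls back to the default entry otherwise); this is an inferred sketch, not the PR's actual implementation:

#include <type_traits>

// Value -> type map (sketch): query<K> yields the type paired with K, else Default.
template <auto Key, typename T>
struct kValTyPair {
  static constexpr auto key = Key;
  using type = T;
};

template <typename Default, typename... Pairs>
struct kValTyMap {
  template <auto K>
  using query = Default;
};

template <typename Default, typename Head, typename... Tail>
struct kValTyMap<Default, Head, Tail...> {
  template <auto K>
  using query = std::conditional_t<
      K == Head::key,
      typename Head::type,
      typename kValTyMap<Default, Tail...>::template query<K>>;
};

// Value -> value map (sketch): query<K> yields the value paired with K, else Default.
template <auto Key, auto Val>
struct kValValPair {
  static constexpr auto key = Key;
  static constexpr auto value = Val;
};

template <auto Default, typename... Pairs>
struct kValValMap {
  template <auto K>
  static constexpr auto lookup() {
    auto result = Default;
    // keep the value of the matching pair, or Default if none matches
    ((Pairs::key == K ? (void)(result = Pairs::value) : (void)0), ...);
    return result;
  }
  template <auto K>
  static constexpr auto query = lookup<K>();
};

// Mirrors the selector above: TileM=64 with a 2-byte element -> 6 stages.
using Sel = kValTyMap<void,
    kValTyPair<64, kValValMap<65536, kValValPair<1, 12>, kValValPair<2, 6>>>,
    kValTyPair<128, kValValMap<65536, kValValPair<1, 11>, kValValPair<2, 5>>>>;
static_assert(Sel::query<64>::query<2> == 6);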
@@ -129,28 +154,52 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     };
   };
 
+  // indices for V0 / V1
+  enum : int {
+    kIdxOldRowMax = 0,
+    kIdxNewRowMax = 1,
+    kIdxFinalRowSum = 0,
+    kIdxFinalRowMax = 1,
+    kIdxStatsEnd = 2
+  };
+
+  // Each stage reserves kTMEM_V_COLUMNS columns for the row max/sum stats.
+  // TileM=64 uses 16dp64b --> two threads process a row
+  // TileM=128 uses 32dp32b --> one thread processes a row
+  using kTMEM_V_COLUMNS = typename kValTyMap<
+      void,
+      kValTyPair<64, Int<4>>,
+      kValTyPair<128, Int<2>>
+  >::template query<TileM::value>;
+
+  // TMEM column allocation; the offsets below form the lower 16 bits of TMEM addresses.
+  // The TMEM row/lane dimension maps to the Q dim.
   enum class TmemAllocation : uint32_t {
-    kSizeS = 128,
-    kSizeO = 128,
-    kSizeP = 32,
+    kSizeS = get<1>(TileShapeQK{}),  // i.e., the KV dim of a tile
+    kSizeO = get<2>(TileShapeQK{}),  // i.e., the head dim
+    // Carve kSizeS into two parts: the first half for V0/V1 stats storage, the rest for P0/P1.
+    // This wastes some storage, but there seem to be column-wise address alignment requirements
+    // not found in the spec. Since there is enough storage left for P0/P1, chose not to debug
+    // the alignment issues.
+    kSizeV = kSizeS / 2,
+    // P will be cast to the same type as V
+    kSizeP = kSizeS * sizeof(Element) / sizeof(float),
     S0 = 0,
     S1 = S0 + kSizeS,
     V0 = S0,  // stats storage from softmax to correction
     V1 = S1,
-    P0 = S0 + kSizeP,
-    P1 = S1 + kSizeP,
+    P0 = V0 + kSizeV,
+    P1 = V1 + kSizeV,
     O0 = S1 + kSizeS,
     O1 = O0 + kSizeO,
     kEnd = O1 + kSizeO
   };
-
-  // indices for V0 / V1
-  enum : int {
-    kIdxOldRowMax = 0,
-    kIdxNewRowMax = 1,
-    kIdxFinalRowSum = 0,
-    kIdxFinalRowMax = 1
-  };
+  static_assert(static_cast<uint32_t>(TmemAllocation::kEnd) <= 512, "Exceeds TMEM 512 columns");
+  static_assert(
+      static_cast<uint32_t>(TmemAllocation::kSizeV) + static_cast<uint32_t>(TmemAllocation::kSizeP) <=
+          static_cast<uint32_t>(TmemAllocation::kSizeS),
+      "Not enough storage to carve V and P out of S");
+  static_assert(
+      static_cast<uint32_t>(kTMEM_V_COLUMNS::value) <= static_cast<uint32_t>(TmemAllocation::kSizeV),
+      "Not enough storage reserved for V");
 
   // from load to mma warp, protects q in smem
   using PipelineQ = cutlass::PipelineUmmaConsumerAsync<
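Concretely, the static_asserts above pin the allocation to the 512-column TMEM budget. A standalone recomputation under assumed tile sizes (kSizeS = kSizeO = 128, a 2-byte element type such as bf16):

// Recomputes the TmemAllocation arithmetic above (assumed: kSizeS = kSizeO = 128,
// sizeof(Element) == 2 for bf16/fp16).
constexpr int kSizeS = 128;                        // KV dim of the tile
constexpr int kSizeO = 128;                        // head dim
constexpr int kSizeV = kSizeS / 2;                 // 64 columns carved out for stats
constexpr int kSizeP = kSizeS * 2 / 4;             // 64 columns for P (32 for fp8)
constexpr int S0 = 0, S1 = S0 + kSizeS;            // 0, 128
constexpr int P0 = S0 + kSizeV, P1 = S1 + kSizeV;  // 64, 192
constexpr int O0 = S1 + kSizeS, O1 = O0 + kSizeO;  // 256, 384
constexpr int kEnd = O1 + kSizeO;                  // 512
static_assert(kEnd <= 512, "fits the 512-column TMEM budget exactly");
static_assert(kSizeV + kSizeP <= kSizeS, "V and P fit inside the S carve-out");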
@@ -533,41 +582,41 @@ struct Sm100FmhaGenMainloopWarpspecialized {
 
     Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
     tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);
 
-    Tensor tStS_v =
-        tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));
-    tStS_v.data() =
+    Tensor tStS_v =
+        tStS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{})));
+    tStS_v.data() =
         uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
-    Tensor tScS_v =
-        tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));
+    Tensor tScS_v =
+        tScS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{})));
 
-    auto tilePlikeFP32 = _32{};  // 32 for FP32
-    // size<1>(TileShapeQK{}) / Int<sizeof(ElementQK)>{} * Int<sizeof(Element)>{};
-
-    // tilePlikeFP32 = 64/4*2 = 32 for BF16
-    // Preserve hierarchical structure: ((16, 4), 32) = 16*4*32 = 2048 elements
+    auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(ElementQK)>{} * Int<sizeof(Element)>{};
     Tensor tStS_P = tStS.compose(
-        make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32)));
+        make_layout(make_shape(TileM{}, tilePlikeFP32)));
     tStS_P.data() = warp_uniform(
-        uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
+        uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(
-        make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32)));
+        make_layout(make_shape(TileM{}, tilePlikeFP32)));
+
+    // number of TMEM columns to move between TMEM and registers per copy
+    constexpr int kConversionsPerStep = 2;
+    constexpr int kTmemLoadNcells = cute::min(32, size<1>(TileShapeQK{}) / kConversionsPerStep);
+    constexpr int kTmemStoreNcells = kTmemLoadNcells * sizeof_bits_v<Element> / sizeof_bits_v<ElementQK>;
 
-    // Select TMEM operation based on K dimension (number of columns)
-    // For K=64: 64 rows × 64 cols = 4,096 elements → use 16dp32b4x
-    // For K=128: 64 rows × 128 cols = 8,192 elements → use 16dp32b8x
-    using TMEM_LOAD = conditional_t<
-        size<1>(TileShapeQK{}) == _64{},
-        SM100_TMEM_LOAD_16dp32b16x,  // For K=64: 4,096 elements
-        SM100_TMEM_LOAD_16dp32b8x>;  // For K=128: 8,192 elements
+    using TMEM_LOAD_1xOP = typename kValTyMap<
+        void,
+        kValTyPair<64, SM100_TMEM_LOAD_16dp32b1x>,
+        // Each thread owns a single row
+        kValTyPair<128, SM100_TMEM_LOAD_32dp32b1x>
+    >::template query<TileM::value>;
+    using TMEM_STORE_1xOP = decltype(TMEM::tmem_load_to_store(TMEM_LOAD_1xOP{}));
+    using TMEM_LOAD = decltype(TMEM::op_repeater<TMEM_LOAD_1xOP, kTmemLoadNcells>());
+    using TMEM_STORE = decltype(TMEM::op_repeater<TMEM_STORE_1xOP, kTmemStoreNcells>());
 
-    using TMEM_STORE = conditional_t<
-        size<1>(TileShapeQK{}) == _64{},
-        SM100_TMEM_STORE_16dp32b8x,  // For K=64, BF16: 2,048 elements
-        SM100_TMEM_STORE_16dp32b8x>;
+    using TMEM_STORE_V = typename kValTyMap<
+        void,
+        kValTyPair<64, SM100_TMEM_STORE_16dp32b2x>,
+        kValTyPair<128, SM100_TMEM_STORE_32dp32b2x>  // 4x32 threads with 2 cols of 32b elem
+    >::template query<TileM::value>;
 
-    // TMEM_STORE_V: Store row statistics (old_max, new_max) for online softmax
-    // correction Always 64 rows × 2 cols = 128 FP32 elements
-    using TMEM_STORE_V = SM100_TMEM_STORE_16dp32b2x;
 
     auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);
     auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
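To make the repeat-count arithmetic above concrete, the same computation for an assumed 128-column KV tile with bf16 inputs and fp32 accumulation:

// Mirrors kTmemLoadNcells / kTmemStoreNcells above for an assumed tile.
constexpr int kConversionsPerStep = 2;   // two fp32 values converted per step
constexpr int kTileN = 128;              // size<1>(TileShapeQK{}), the KV dim
constexpr int kLoadNcells =
    (kTileN / kConversionsPerStep) < 32 ? (kTileN / kConversionsPerStep) : 32;  // min(32, 64) = 32
// Stores shrink by the element-width ratio: bf16 (16 bits) vs fp32 (32 bits).
constexpr int kStoreNcells = kLoadNcells * 16 / 32;  // 16 (for fp8: 32 * 8 / 32 = 8)
static_assert(kLoadNcells == 32 && kStoreNcells == 16);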
@@ -616,12 +665,15 @@ struct Sm100FmhaGenMainloopWarpspecialized {
       row_max = ::fmax(row_max_0, row_max_1);
       row_max = ::fmax(row_max, row_max_2);
       row_max = ::fmax(row_max, row_max_3);
+      if constexpr (TileM{} == 64) {
+        ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16);
+        row_max = ::fmax(row_max, shuffled_row_max);
+      }
     }
 
-    ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16);
-    row_max = ::fmax(row_max, shuffled_row_max);
     ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;
 
     Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
+    static_assert(size(tTMEM_STOREVrS) == 2);
     tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;
     tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;
     copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
@@ -639,48 +691,64 @@ struct Sm100FmhaGenMainloopWarpspecialized {
 
     Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));
-    constexpr int kConversionsPerStep = 2;
+
     Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);
     NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
 
-    const int kReleasePipeCount = 10;  // must be multiple of 2
 
     order_s.wait();
 
+    static_assert(kReleasePipeCount % kConversionsPerStep == 0);
+    static_assert(kConversionsPerStep == 2);
 
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {
-      float2 in = make_float2(
-          tTMEM_LOADrS(i + 0),
-          tTMEM_LOADrS(i + 1)
-      );
-      float2 out;
-      cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);
-      tTMEM_LOADrS(i + 0) = out.x;
-      tTMEM_LOADrS(i + 1) = out.y;
-
-      tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));
-      tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));
-
-      Array<ElementQK, kConversionsPerStep> in_conv;
+    {
       CUTLASS_PRAGMA_UNROLL
-      for (int j = 0; j < kConversionsPerStep; j++) {
-        in_conv[j] = tTMEM_LOADrS(i + j);
-      }
-      tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
+      for (int i = 0; i < size(tTMEM_LOADrS); i += kConversionsPerStep) {
+        float2 in = make_float2(
+            tTMEM_LOADrS(i + 0),
+            tTMEM_LOADrS(i + 1)
+        );
 
-
-      if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
-        order_s.arrive();
+        float2 out;
+        cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);
+        tTMEM_LOADrS(i + 0) = out.x;
+        tTMEM_LOADrS(i + 1) = out.y;
+
+        tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));
+        tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));
+
+        Array<ElementQK, kConversionsPerStep> in_conv;
+        CUTLASS_PRAGMA_UNROLL
+        for (int j = 0; j < kConversionsPerStep; j++) {
+          in_conv[j] = tTMEM_LOADrS(i + j);
+        }
+        tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
+
+        if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
+          order_s.arrive();
+        }
+
+        if constexpr (TileM::value == 128) {
+          if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {
+            // this prevents register spills in fp16
+            if (i == size(tTMEM_LOADrS) - 6) {
+              copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));
+            }
+          }
+        }
       }
-
-    }
+    }
 
     // tmem_store(reg_S8) -> op_P
-    // CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
-    // CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
-    copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4);
+    CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
+    CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
+    if constexpr (TileM::value == 128) {
+      copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));
+    } else {
+      copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4);
+    }
 
     cutlass::arch::fence_view_async_tmem_store();
@@ -722,10 +790,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
       row_sum = local_row_sum;
 
       if (final_call) {
+        if constexpr (TileM{} == 64) {
+          // Sync threads 0 and 16 to get the sum of row_sum between them
+          row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16);
+        }
+
         // re-acquire the S part in the final step
         pipeline_s.consumer_wait(pipeline_s_consumer_state);
-        // Sync threads 0 and 16 to get the sum of row_sum between them
-        row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16);
+
         Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
         tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;
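The TileM == 64 branches in the hunks above rely on the 16-datapath lane mapping, in which lanes t and t ^ 16 of a warp cooperate on one row. A standalone sketch of that pairwise combine, using the same intrinsic (function names are illustrative):

// Lanes t and t ^ 16 hold partial stats for the same row; one XOR shuffle
// lets both lanes see the combined value (CUDA device code).
__device__ float combine_row_max(float partial_max) {
  float other = __shfl_xor_sync(0xffffffff, partial_max, 16);
  return fmaxf(partial_max, other);
}

__device__ float combine_row_sum(float partial_sum) {
  return partial_sum + __shfl_xor_sync(0xffffffff, partial_sum, 16);
}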
@@ -815,21 +887,31 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     const int kCorrectionTileSize = 32 / sizeof(ElementOut);
 
     // TODO: load all values
-    using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_16dp32b16x, SM100_TMEM_LOAD_16dp32b8x>;  // 4x32 threads with 64 cols of 32b elem
+
+    // Choose the TMEM op based on:
+    // - the TileM shape
+    // - kCorrectionTileSize
+    using TMEM_LOAD_OPMAP = kValTyMap<
+      void,
+      kValTyPair<64,
+        kValTyMap<
+          void,
+          kValTyPair<16, SM100_TMEM_LOAD_16dp32b8x>,
+          kValTyPair<32, SM100_TMEM_LOAD_16dp32b16x>
+        >
+      >,
+      kValTyPair<128,
+        kValTyMap<
+          void,
+          kValTyPair<16, SM100_TMEM_LOAD_32dp32b16x>,
+          kValTyPair<32, SM100_TMEM_LOAD_32dp32b32x>
+        >>  // 4x32 threads with 64 cols of 32b elem
+    >;
+    using TMEM_LOAD = typename TMEM_LOAD_OPMAP::template query<TileM::value>::template query<kCorrectionTileSize>;
 
     typename CollectiveMmaPV::TiledMma mma;
     Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
     Tensor tOcO = mma.get_slice(0).partition_C(cO);
     Tensor tOgO = mma.get_slice(0).partition_C(gO);
-
-    Tensor tOtO_i = tOtO.compose(make_layout(
-        make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
-    Tensor tOcO_i = tOcO.compose(make_layout(
-        make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
-    Tensor tOgO_i = tOgO.compose(make_layout(
-        make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
+
+    Tensor tOtO_i = tOtO.compose(make_layout(make_shape(TileM{}, Int<kCorrectionTileSize>{})));
+    Tensor tOcO_i = tOcO.compose(make_layout(make_shape(TileM{}, Int<kCorrectionTileSize>{})));
+    Tensor tOgO_i = tOgO.compose(make_layout(make_shape(TileM{}, Int<kCorrectionTileSize>{})));
 
     Tensor tOtO0 = tOtO_i;
     tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0);
@@ -895,13 +977,13 @@ struct Sm100FmhaGenMainloopWarpspecialized {
         tCd(j) = convert.convert(tCs(j));
       }
 
-      Tensor tSMgO_i = recast<uint32_t>(tTMEM_LOADgO_i);
-      Tensor tSMrO_i = recast<uint32_t>(tSMrO);
+      Tensor tSMgO_i = recast<uint32_t>(tTMEM_LOADgO_i);
+      Tensor tSMrO_i = recast<uint32_t>(tSMrO);
 
-      // could use masking do this right for smaller D
-      if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) {
+      // could use masking to do this right for smaller D
+      if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) {
         copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i);
-      }
+      }
     }
   }
@@ -917,18 +999,22 @@ struct Sm100FmhaGenMainloopWarpspecialized {
 
     // good values would be either 32 or 64
     const int kCorrectionTileSize = 32;
 
-    using TMEM_LOAD = SM100_TMEM_LOAD_16dp32b16x;    // 4x32 threads with 64 cols of 32b elem
-    using TMEM_STORE = SM100_TMEM_STORE_16dp32b16x;  // 4x32 threads with 64 cols of 32b elem
+    using TMEM_LOAD = typename kValTyMap<
+        void,
+        kValTyPair<64, SM100_TMEM_LOAD_16dp32b16x>,
+        kValTyPair<128, SM100_TMEM_LOAD_32dp32b32x>  // 4x32 threads with 64 cols of 32b elem
+    >::template query<TileM::value>;
+    using TMEM_STORE = typename kValTyMap<
+        void,
+        kValTyPair<64, SM100_TMEM_STORE_16dp32b16x>,
+        kValTyPair<128, SM100_TMEM_STORE_32dp32b32x>  // 4x32 threads with 64 cols of 32b elem
+    >::template query<TileM::value>;
 
     typename CollectiveMmaPV::TiledMma mma;
     Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
     Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
     Tensor tOcO = mma.get_slice(0).partition_C(cO);
-
-    Tensor tOtO_i = tOtO.compose(make_layout(
-        make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
-    Tensor tOcO_i = tOcO.compose(make_layout(
-        make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
+
+    Tensor tOtO_i = tOtO.compose(make_layout(make_shape(TileM{}, Int<kCorrectionTileSize>{})));
+    Tensor tOcO_i = tOcO.compose(make_layout(make_shape(TileM{}, Int<kCorrectionTileSize>{})));
 
     tOtO_i.data() = tOtO_i.data().get() + tmem_O;
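A plausible reading of the repeat counts in this hunk: each O row spans kCorrectionTileSize = 32 accumulator columns, the 16dp ops map two lanes to a row (16 columns per lane, hence the 16x variants), and the 32dp ops map one lane per row (32 columns, hence 32x):

// Repeat count per lane = row columns / lanes sharing the row (illustrative).
constexpr int kCorrectionTileSize = 32;
constexpr int kLanesPerRow16dp = 2;  // TileM == 64 mapping
constexpr int kLanesPerRow32dp = 1;  // TileM == 128 mapping
static_assert(kCorrectionTileSize / kLanesPerRow16dp == 16);  // SM100_TMEM_LOAD_16dp32b16x
static_assert(kCorrectionTileSize / kLanesPerRow32dp == 32);  // SM100_TMEM_LOAD_32dp32b32x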
@@ -1009,13 +1095,15 @@ struct Sm100FmhaGenMainloopWarpspecialized {
 
     Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
     Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
-
-    Tensor tStS_v = tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));
-    Tensor tScS_v = tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));
 
-    using TMEM_LOAD_V =
-        SM100_TMEM_LOAD_16dp32b2x;  // 4x32 threads with 2 cols of 32b elem
+    Tensor tStS_v = tStS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{})));
+    Tensor tScS_v = tScS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{})));
 
+    using TMEM_LOAD_V = typename kValTyMap<
+        void,
+        kValTyPair<64, SM100_TMEM_LOAD_16dp32b2x>,
+        kValTyPair<128, SM100_TMEM_LOAD_32dp32b2x>  // 4x32 threads with 2 cols of 32b elem
+    >::template query<TileM::value>;
 
     auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);
     auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx);
@@ -1043,6 +1131,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
 
     Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
+    static_assert(size(tTMEM_LOADVrS) == 2);
     // read row_wise new global max
     copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
index 42057aeafe..a3a51d15b4 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
@@ -687,7 +687,7 @@ def _execute_cutlass_blackwell_attn_varlen(
             sm_scale,
             num_groups,
         )
-        for dtype in [torch.bfloat16]
+        for dtype in [torch.bfloat16, torch.float8_e4m3fn]
        for seqlen_k in [64, 128, 256, 1024]
         for batch_size in [1, 2]
         for is_mqa in [True, False]