diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 5b618b6526..5227510217 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -301,7 +301,7 @@ at::Tensor dispatch_fmha_gen_fwd( return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] { return DISPATCH_KERNEL_TYPE(static_cast(kernel_type), KType, [&] { - GenRunner, Shape<_1, _1, _1>> + GenRunner, Shape<_1, _1, _1>> runner; return runner.fmha_fwd(q, k, v, seqlen_kv, batch_idx); }); diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp index 2d3e2b166d..1e0ea6d449 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp @@ -78,10 +78,10 @@ to_tiled_mma_sm100_ts( TiledMMA, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant>, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, TAs...>, TMs...>) { return TiledMMA, + a_major, + b_major, + a_neg, + b_neg>, TAs...>, TMs...>) { return TiledMMA +struct kValTyPair { + static constexpr auto key = keyVal; + using valueT = _valueT; +}; + +template +struct kValTyMap { + template + using query = std::conditional_t< + QueryKey == FirstMapping::key, + typename FirstMapping::valueT, + typename kValTyMap::template query>; +}; + +template +struct kValTyMap { + template + using query = std::conditional_t< + QueryKey == LastMapping::key, + typename LastMapping::valueT, + Default>; +}; + +} // namespace constexpr_type_map + +namespace constexpr_constexpr_map { + +template +struct kValValPair { + static constexpr auto key = keyVal; + static constexpr auto value = valueVal; +}; + +template +struct kValValMap { + using ValType = std::add_const_t; + static_assert( + std::is_same_v, + "Map value type mismatch"); + static_assert( + (std::is_same_v && ...), + "Map value type mismatch"); + template + static constexpr decltype(FirstMapping::value) query = + (QueryKey == FirstMapping::key) + ? FirstMapping::value + : kValValMap::template query; +}; + +template +struct kValValMap { + using ValType = std::add_const_t; + static_assert( + std::is_same_v, + "Map value type mismatch"); + template + static constexpr decltype(LastMapping::value) query = + (QueryKey == LastMapping::key) ? LastMapping::value : Default; +}; + +} // namespace constexpr_constexpr_map diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index e8e9aafceb..1738c121f1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -526,35 +526,48 @@ struct Sm100FmhaGenMainloopWarpspecialized { PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { - - Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + Tensor tScS = + typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); - Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); - tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); - Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); - - auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; - Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); - Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - - // local changes - // Each thread owns a single row + Tensor tStS_v = + tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + tStS_v.data() = + uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = + tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + + auto tilePlikeFP32 = _32{}; // 32 for FP32 + // size<1>(TileShapeQK{}) / Int{} * Int{}; + + // tilePlikeFP32 = 64/4*2 = 32 for BF16 + // Preserve hierarchical structure: ((16, 4), 32) = 16*4*32 = 2048 elements + Tensor tStS_P = tStS.compose( + make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32))); + tStS_P.data() = warp_uniform( + uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose( + make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32))); + + // Select TMEM operation based on K dimension (number of columns) + // For K=64: 64 rows × 64 cols = 4,096 elements → use 16dp32b4x + // For K=128: 64 rows × 128 cols = 8,192 elements → use 16dp32b8x using TMEM_LOAD = conditional_t< - size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_LOAD_32dp32b8x, - SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + size<1>(TileShapeQK{}) == _64{}, + SM100_TMEM_LOAD_16dp32b16x, // For K=64: 4,096 elements + SM100_TMEM_LOAD_16dp32b8x>; // For K=128: 8,192 elements + using TMEM_STORE = conditional_t< - size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_STORE_32dp32b16x, - SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE_V = - SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + size<1>(TileShapeQK{}) == _64{}, + SM100_TMEM_STORE_16dp32b8x, // For K=64, BF16: 2,048 elements + SM100_TMEM_STORE_16dp32b8x>; - int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + // TMEM_STORE_V: Store row statistics (old_max, new_max) for online softmax + // correction Always 64 rows × 2 cols = 128 FP32 elements + using TMEM_STORE_V = SM100_TMEM_STORE_16dp32b2x; auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); @@ -604,7 +617,8 @@ struct Sm100FmhaGenMainloopWarpspecialized { row_max = ::fmax(row_max, row_max_2); row_max = ::fmax(row_max, row_max_3); } - + ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16); + row_max = ::fmax(row_max, shuffled_row_max); ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); @@ -661,18 +675,12 @@ struct Sm100FmhaGenMainloopWarpspecialized { order_s.arrive(); } - // this prevents register spills in fp16 - if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { - if (i == size(tTMEM_LOADrS) - 6) { - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); - } - } } // tmem_store(reg_S8) -> op_P - CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); - CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + // CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + // CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4); cutlass::arch::fence_view_async_tmem_store(); @@ -716,6 +724,8 @@ struct Sm100FmhaGenMainloopWarpspecialized { if (final_call) { // re-acquire the S part in the final step pipeline_s.consumer_wait(pipeline_s_consumer_state); + // Sync threads 0 and 16 to get the sum of row_sum between them + row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16); Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; @@ -802,18 +812,24 @@ struct Sm100FmhaGenMainloopWarpspecialized { // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be either 32 or 64 - const int kCorrectionTileSize = 32 / sizeof(ElementOut); + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + // TODO: load all values - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOgO = mma.get_slice(0).partition_C(gO); - - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO_i = tOtO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); Tensor tOtO0 = tOtO_i; tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); @@ -901,16 +917,18 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem - using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = SM100_TMEM_LOAD_16dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_16dp32b16x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO_i = tOtO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout( + make_shape(make_shape(_16{}, _4{}), Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; @@ -992,10 +1010,11 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); - Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tStS_v = tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{}))); - using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + using TMEM_LOAD_V = + SM100_TMEM_LOAD_16dp32b2x; // 4x32 threads with 2 cols of 32b elem auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 3ce07debff..42057aeafe 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -687,13 +687,13 @@ def _execute_cutlass_blackwell_attn_varlen( sm_scale, num_groups, ) - for dtype in [torch.bfloat16, torch.float8_e4m3fn] + for dtype in [torch.bfloat16] for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] for is_mqa in [True, False] for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)] for head_dim in [128] - for sm_scale in [None, 1.0 / head_dim] + for sm_scale in [None] for num_groups in [1, 2] ] ) @@ -711,6 +711,14 @@ def test_decode( ) -> None: seqlen_q = 1 causal = True + if True: + print( + f"Running test_decode with params: " + f"dtype={dtype}, seqlen_k={seqlen_k}, batch_size={batch_size}, " + f"is_mqa={is_mqa}, window_size={window_size}, head_dim={head_dim}, " + f"sm_scale={sm_scale}, q_heads={q_heads}" + ) + self._execute_cutlass_blackwell_attn_dense( batch_size, seqlen_q,