Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ at::Tensor dispatch_fmha_gen_fwd(

return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] {
return DISPATCH_KERNEL_TYPE(static_cast<int>(kernel_type), KType, [&] {
GenRunner<Element, KType, Shape<_128, _128, _128>, Shape<_1, _1, _1>>
GenRunner<Element, KType, Shape<_64, _128, _128>, Shape<_1, _1, _1>>
runner;
return runner.fmha_fwd(q, k, v, seqlen_kv, batch_idx);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
TAs...>, TMs...>) {

return TiledMMA<MMA_Atom<
Expand All @@ -101,10 +101,10 @@ to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
M, N,
a_major,
b_major,
a_neg,
b_neg>,
a_major,
b_major,
a_neg,
b_neg>,
TAs...>, TMs...>) {
return TiledMMA<MMA_Atom<
SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
Expand All @@ -125,4 +125,75 @@ void warpgroup_reg_set() {
}
}

} // namespace cutlass::fmha::collective
} // namespace cutlass::fmha::collective

namespace constexpr_type_map {
/*
* The following utility type_traits allow mapping constexpr variable to type at
* compile time.
* The default return type defined for each map would be returned if queried key
* does not exist in the map.
*/

template <auto keyVal, typename _valueT>
struct kValTyPair {
static constexpr auto key = keyVal;
using valueT = _valueT;
};

template <typename Default, typename FirstMapping, typename... OtherMapping>
struct kValTyMap {
template <auto QueryKey>
using query = std::conditional_t<
QueryKey == FirstMapping::key,
typename FirstMapping::valueT,
typename kValTyMap<Default, OtherMapping...>::template query<QueryKey>>;
};

template <typename Default, typename LastMapping>
struct kValTyMap<Default, LastMapping> {
template <auto QueryKey>
using query = std::conditional_t<
QueryKey == LastMapping::key,
typename LastMapping::valueT,
Default>;
};

} // namespace constexpr_type_map

namespace constexpr_constexpr_map {

template <auto keyVal, auto valueVal>
struct kValValPair {
static constexpr auto key = keyVal;
static constexpr auto value = valueVal;
};

template <auto Default, typename FirstMapping, typename... OtherMapping>
struct kValValMap {
using ValType = std::add_const_t<decltype(Default)>;
static_assert(
std::is_same_v<ValType, decltype(FirstMapping::value)>,
"Map value type mismatch");
static_assert(
(std::is_same_v<ValType, decltype(OtherMapping::value)> && ...),
"Map value type mismatch");
template <decltype(FirstMapping::key) QueryKey>
static constexpr decltype(FirstMapping::value) query =
(QueryKey == FirstMapping::key)
? FirstMapping::value
: kValValMap<Default, OtherMapping...>::template query<QueryKey>;
};

template <auto Default, typename LastMapping>
struct kValValMap<Default, LastMapping> {
using ValType = std::add_const_t<decltype(Default)>;
static_assert(
std::is_same_v<ValType, decltype(LastMapping::value)>,
"Map value type mismatch");
template <decltype(LastMapping::key) QueryKey>
static constexpr decltype(LastMapping::value) query =
(QueryKey == LastMapping::key) ? LastMapping::value : Default;
};

} // namespace constexpr_constexpr_map
Original file line number Diff line number Diff line change
Expand Up @@ -526,35 +526,48 @@ struct Sm100FmhaGenMainloopWarpspecialized {
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
OrderBarrierSoftmax& order_s) {

Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Tensor tScS =
typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);

Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);

Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));

auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));

// local changes
// Each thread owns a single row
Tensor tStS_v =
tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));
tStS_v.data() =
uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v =
tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));

auto tilePlikeFP32 = _32{}; // 32 for FP32
// size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};

// tilePlikeFP32 = 64/4*2 = 32 for BF16
// Preserve hierarchical structure: ((16, 4), 32) = 16*4*32 = 2048 elements
Tensor tStS_P = tStS.compose(
make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32)));
tStS_P.data() = warp_uniform(
uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(
make_layout(make_shape(make_shape(_16{}, _4{}), tilePlikeFP32)));

// Select TMEM operation based on K dimension (number of columns)
// For K=64: 64 rows × 64 cols = 4,096 elements → use 16dp32b4x
// For K=128: 64 rows × 128 cols = 8,192 elements → use 16dp32b8x
using TMEM_LOAD = conditional_t<
size<1>(TileShapeQK{}) < _128{},
SM100_TMEM_LOAD_32dp32b8x,
SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
size<1>(TileShapeQK{}) == _64{},
SM100_TMEM_LOAD_16dp32b16x, // For K=64: 4,096 elements
SM100_TMEM_LOAD_16dp32b8x>; // For K=128: 8,192 elements

using TMEM_STORE = conditional_t<
size<1>(TileShapeQK{}) < _128{},
SM100_TMEM_STORE_32dp32b16x,
SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V =
SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
size<1>(TileShapeQK{}) == _64{},
SM100_TMEM_STORE_16dp32b8x, // For K=64, BF16: 2,048 elements
SM100_TMEM_STORE_16dp32b8x>;

int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// TMEM_STORE_V: Store row statistics (old_max, new_max) for online softmax
// correction Always 64 rows × 2 cols = 128 FP32 elements
using TMEM_STORE_V = SM100_TMEM_STORE_16dp32b2x;

auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Expand Down Expand Up @@ -604,7 +617,8 @@ struct Sm100FmhaGenMainloopWarpspecialized {
row_max = ::fmax(row_max, row_max_2);
row_max = ::fmax(row_max, row_max_3);
}

ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16);
row_max = ::fmax(row_max, shuffled_row_max);
ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;

Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
Expand Down Expand Up @@ -661,18 +675,12 @@ struct Sm100FmhaGenMainloopWarpspecialized {
order_s.arrive();
}

// this prevents register spills in fp16
if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {
if (i == size(tTMEM_LOADrS) - 6) {
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));
}
}
}

// tmem_store(reg_S8) -> op_P
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));
// CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
// CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4);

cutlass::arch::fence_view_async_tmem_store();

Expand Down Expand Up @@ -716,6 +724,8 @@ struct Sm100FmhaGenMainloopWarpspecialized {
if (final_call) {
// re-acquire the S part in the final step
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// Sync threads 0 and 16 to get the sum of row_sum between them
row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16);

Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;
Expand Down Expand Up @@ -802,18 +812,24 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
const int kCorrectionTileSize = 32 / sizeof(ElementOut);
// TODO: load all values

using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32,
SM100_TMEM_LOAD_16dp32b16x
, SM100_TMEM_LOAD_16dp32b8x>; // 4x32 threads with 64 cols of 32b elem

typename CollectiveMmaPV::TiledMma mma;
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOgO = mma.get_slice(0).partition_C(gO);

Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));

Tensor tOtO_i = tOtO.compose(make_layout(
make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(
make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
Tensor tOgO_i = tOgO.compose(make_layout(
make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));

Tensor tOtO0 = tOtO_i;
tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0);
Expand Down Expand Up @@ -901,16 +917,18 @@ struct Sm100FmhaGenMainloopWarpspecialized {
// good values would be either 32 or 64
const int kCorrectionTileSize = 32;

using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem
using TMEM_LOAD = SM100_TMEM_LOAD_16dp32b16x; // 4x32 threads with 64 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_16dp32b16x; // 4x32 threads with 64 cols of 32b elem

typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);

Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));

Tensor tOtO_i = tOtO.compose(make_layout(
make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(
make_shape(make_shape(_16{}, _4{}), Int<kCorrectionTileSize>{})));

tOtO_i.data() = tOtO_i.data().get() + tmem_O;

Expand Down Expand Up @@ -992,10 +1010,11 @@ struct Sm100FmhaGenMainloopWarpspecialized {
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);

Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tStS_v = tStS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(make_shape(_16{}, _4{}), _4{})));

using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
using TMEM_LOAD_V =
SM100_TMEM_LOAD_16dp32b2x; // 4x32 threads with 2 cols of 32b elem

auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);
auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,13 +687,13 @@ def _execute_cutlass_blackwell_attn_varlen(
sm_scale,
num_groups,
)
for dtype in [torch.bfloat16, torch.float8_e4m3fn]
for dtype in [torch.bfloat16]
for seqlen_k in [64, 128, 256, 1024]
for batch_size in [1, 2]
for is_mqa in [True, False]
for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
for head_dim in [128]
for sm_scale in [None, 1.0 / head_dim]
for sm_scale in [None]
for num_groups in [1, 2]
]
)
Expand All @@ -711,6 +711,14 @@ def test_decode(
) -> None:
seqlen_q = 1
causal = True
if True:
print(
f"Running test_decode with params: "
f"dtype={dtype}, seqlen_k={seqlen_k}, batch_size={batch_size}, "
f"is_mqa={is_mqa}, window_size={window_size}, head_dim={head_dim}, "
f"sm_scale={sm_scale}, q_heads={q_heads}"
)

self._execute_cutlass_blackwell_attn_dense(
batch_size,
seqlen_q,
Expand Down
Loading