Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
18a0cae
add sink test to attention_ref
jayhshah Jun 6, 2025
684c203
embed sink test into main flash attn test script
jayhshah Jun 6, 2025
49be39d
change local block positioning for fast path
jayhshah Jun 16, 2025
bf2ad39
add tma gqa modifications
jayhshah Jun 17, 2025
07b6e77
fix exploding build
jayhshah Jun 17, 2025
3f8cfbe
try to fix exploding build again
jayhshah Jun 17, 2025
8d4de3c
prune other hdims to keep build stable
jayhshah Jun 17, 2025
2145fea
tweak tile size for causal not local case
jayhshah Jun 17, 2025
f293967
split compilation
jayhshah Jun 17, 2025
95cc109
fix error with varlen q
jayhshah Jun 17, 2025
e1217ab
split compilation for root setup
jayhshah Aug 3, 2025
1a3cd09
re-enable use one mma wg
jayhshah Aug 3, 2025
5d58d4b
update for hdim diff
jayhshah Aug 4, 2025
6a2c39a
include comments on how to enable hdim diff in cmakelists
jayhshah Aug 4, 2025
0529f4f
fix test variable
jayhshah Aug 4, 2025
a4e9b0f
update pack gqa heuristic
jayhshah Aug 4, 2025
53b576f
fix comment
jayhshah Aug 4, 2025
a763abf
split disable hdim diff macro into 64 and 192, enable 192 by default …
jayhshah Aug 5, 2025
786873d
add assert to check using Hopper kernels with s_aux
jayhshah Aug 5, 2025
e5935eb
more logical placement of assert checks for hdim diff in flash api
jayhshah Aug 5, 2025
323bb43
cherrypick
LucasWilkinson Aug 7, 2025
4bca7a3
typo fix
LucasWilkinson Aug 7, 2025
330171b
review comments
LucasWilkinson Aug 8, 2025
4e44ba2
review comment
LucasWilkinson Aug 8, 2025
6a3d9c8
Merge remote-tracking branch 'upstream/main' into lwilkinson/aux-fast…
LucasWilkinson Aug 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,27 +178,48 @@ endif ()
if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim64_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_bf16*_sm90.cu")
# Add these for hdim diff cases
file(GLOB FA3_BF16_GEN_SRCS_
"hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
# "hopper/instantiations/flash_fwd_hdim64_256_bf16*_sm90.cu"
# "hopper/instantiations/flash_fwd_hdim64_512_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_128_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
file(GLOB FA3_BF16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim64_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_fp16*_sm90.cu")
# Add these for hdim diff cases
file(GLOB FA3_FP16_GEN_SRCS_
"hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
# "hopper/instantiations/flash_fwd_hdim64_256_fp16*_sm90.cu"
# "hopper/instantiations/flash_fwd_hdim64_512_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_128_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
file(GLOB FA3_FP16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim64_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim96_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim128_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_e4m3*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim256_e4m3*_sm90.cu")
# Add these for hdim diff cases (192 only)
file(GLOB FA3_FP8_GEN_SRCS_
"hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
"hopper/instantiations/flash_fwd_hdim192_128_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
Expand Down Expand Up @@ -244,11 +265,17 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
FLASHATTENTION_DISABLE_CLUSTER # disabled for varlen in any case
# FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_HDIMDIFF64
# FLASHATTENTION_DISABLE_HDIMDIFF192
CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED
CUTLASS_ENABLE_GDC_FOR_SM90
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
Expand Down
37 changes: 25 additions & 12 deletions hopper/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,39 @@ struct BlockMN {

static
CUTLASS_DEVICE
cute::tuple<int, int> get_n_block_min_max(
cute::tuple<int, int, int> get_n_block_min_max(
SeqlenInfo_t const& seqlen_info,
int const m_block, int const bidb, int const split_idx, int const num_splits,
int const window_size_left, int const window_size_right,
cutlass::FastDivmod const& qhead_per_khead_divmod) {

int const seqlen_k = seqlen_info.seqlen_k;
int seqlen_k = seqlen_info.seqlen_k;
int const seqlen_q = seqlen_info.seqlen_q;
int n_offset = 0;

// If local, calculate n_offset and update seqlen_k
if constexpr (Is_local) {
int m_idx_min = m_block * kBlockM;
if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
// unlike previously, we don't divide by kBlockN because we want offset for seqlen_k
n_offset = std::max(int(0), m_idx_min + seqlen_k - seqlen_q - window_size_left);
// Subtract n_offset from seqlen_k for subsequent calculations such as n_block_max
// This is the actual seqlen_k processed for this m_block
seqlen_k -= n_offset;
}

int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
if constexpr (Is_causal || Is_local) {
int m_idx_max = (m_block + 1) * kBlockM;
// TODO: check off-by-1 error
if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
// If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
n_block_max = std::min(n_block_max,
cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
}
// Now, only adjust n_block_min if split
int n_block_min = 0;
if constexpr (Is_local) {
int m_idx_min = m_block * kBlockM;
if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN);
}

// if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
if constexpr (Split) {
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
Expand All @@ -45,7 +56,9 @@ struct BlockMN {
// if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
}
// if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
return {n_block_min, n_block_max};

// Return n_offset to add to KV gmem pointers and use in masks
return {n_block_min, n_block_max, n_offset};
}

static
Expand All @@ -55,12 +68,12 @@ struct BlockMN {
int const m_block, int const bidb, int const split_idx, int const num_splits,
int const window_size_left, int const window_size_right,
cutlass::FastDivmod const& qhead_per_khead_divmod) {

auto [n_block_min, n_block_max] = get_n_block_min_max(
// TODO: check logic with n_offset
auto [n_block_min, n_block_max, n_offset] = get_n_block_min_max(
seqlen_info, m_block, bidb, split_idx, num_splits,
window_size_left, window_size_right, qhead_per_khead_divmod);
int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
int const idx_k_new_min = std::max(n_block_min * kBlockN + n_offset - seqlen_info.seqlen_k_og, 0);
int const idx_k_new_max = std::min(n_block_max * kBlockN + n_offset - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
int const n_block_new_min = idx_k_new_min / kBlockN;
int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
// if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
Expand Down
36 changes: 22 additions & 14 deletions hopper/epilogue_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace flash {
using namespace cute;

template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false, int kBlockH_=1>
struct CollectiveEpilogueFwd {

using TileShape_MNK_PV = TileShape_MNK_PV_;
Expand All @@ -32,16 +32,18 @@ struct CollectiveEpilogueFwd {
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
static constexpr bool Varlen = Varlen_;
static constexpr bool PackGQA = PackGQA_;
static constexpr bool PackGQA_TMA = PackGQA && (kBlockH_ > 1);
static constexpr bool Split = Split_;
static constexpr bool Use_smem = !(Split && !Varlen);
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && (!PackGQA || PackGQA_TMA);

static_assert(ArchTag::kMinComputeCapability >= 80);
static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
static_assert(sizeof(Element) <= 2);

static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
static constexpr int kBlockH = kBlockH_;

static constexpr bool LargeHeadDimV = kHeadDimV > 256;

Expand Down Expand Up @@ -83,7 +85,10 @@ struct CollectiveEpilogueFwd {
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
// ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
using ShapeOPackedTMA = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<Int<kBlockH>, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
using ShapeOPacked = std::conditional_t<PackGQA && !PackGQA_TMA,
cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>,
ShapeOPackedTMA>;
using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
// ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
Expand All @@ -110,7 +115,7 @@ struct CollectiveEpilogueFwd {
Use_TMA_O,
decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeOPackedTMA{}, StrideOPacked{}),
SmemLayoutOTMA{},
select<0, 1>(TileShape_MNK_PV{}),
_1{})), // no mcast for O
Expand Down Expand Up @@ -158,19 +163,13 @@ struct CollectiveEpilogueFwd {

static Params
to_underlying_arguments(Arguments const& args) {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
TMA_O tma_store_O = [&]{
if constexpr (Use_TMA_O) {
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
} else {
return nullptr;
}
}();
// If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
auto const shape_O_packed = cute::conditional_return<!PackGQA>(
args.shape_O,
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
make_shape(
make_shape(cute::conditional_return<PackGQA_TMA>(Int<kBlockH>{}, qhead_per_khead), get<0>(args.shape_O)),
get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
);
auto const stride_O_packed = cute::conditional_return<!PackGQA>(
args.stride_O,
Expand All @@ -180,6 +179,15 @@ struct CollectiveEpilogueFwd {
args.stride_O_partial,
make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
);
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), shape_O_packed, stride_O_packed);
TMA_O tma_store_O = [&]{
if constexpr (Use_TMA_O) {
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
} else {
return nullptr;
}
}();

// If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
select<0, 2, 3, 4>(args.shape_O),
Expand Down Expand Up @@ -308,7 +316,7 @@ struct CollectiveEpilogueFwd {

// Step 3: Write O from smem -> gmem
if constexpr (Use_TMA_O) {
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O_packed)(_, _, bidh, bidb, split_idx);
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Expand Down
Loading