Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion hopper/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,16 @@ struct Flash_fwd_params : public Qkv_params {
bool pack_gqa;

int * __restrict__ tile_count_semaphore;
// int * __restrict__ num_m_blocks_ptr;
int * __restrict__ num_m_blocks_ptr;
// int * __restrict__ num_n_blocks_ptr;
int * __restrict__ num_splits_dynamic_ptr;
int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual
int * __restrict__ num_nheads_in_l2_ptr;
bool skip_scheduler_metadata_computation;
bool varlen_sort_batches;
int tile_count_semaphore_offset;
bool head_swizzle;
bool prepare_varlen_pdl;

int arch;
int num_sm;
Expand Down
85 changes: 65 additions & 20 deletions hopper/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ PyObject* PyInit__C(void)
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992

void set_params_fprop(Flash_fwd_params &params,
// sizes
const size_t b,
Expand Down Expand Up @@ -250,13 +252,15 @@ void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
if (params.is_bf16) {
#ifndef FLASHATTENTION_DISABLE_HDIM64
if (params.d <= 64) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF64
if constexpr (Arch == 90) {
if (params.dv > 256) {
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
} else if (params.dv > 64) {
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
}
#endif
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
#endif
Expand All @@ -268,11 +272,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
if (params.d <= 192) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
if constexpr (Arch == 90) {
if (params.dv <= 128) {
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
}
#endif
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
#endif
Expand All @@ -283,13 +289,15 @@ void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_FP16
#ifndef FLASHATTENTION_DISABLE_HDIM64
if (params.d <= 64) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF64
if constexpr (Arch == 90) {
if (params.dv > 256) {
return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
} else if (params.dv > 64) {
return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
}
#endif
return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
#endif
Expand All @@ -301,11 +309,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
if (params.d <= 192) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
if constexpr (Arch == 90) {
if (params.dv <= 128) {
return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
}
#endif
return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
#endif
Expand All @@ -329,11 +339,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
if (params.d <= 192) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
if constexpr (Arch == 90) {
if (params.dv <= 128) {
return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
}
#endif
return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
}
#endif
Expand Down Expand Up @@ -525,8 +537,7 @@ mha_fwd_get_scheduler_metadata(
bool has_softcap,
int64_t num_splits,
std::optional<bool> pack_gqa_,
int64_t sm_margin
) {
int64_t sm_margin) {

TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
Expand Down Expand Up @@ -585,8 +596,9 @@ mha_fwd_get_scheduler_metadata(
params.page_size = page_size.has_value() ? page_size.value() : 1;
params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);

bool const use_dynamic_split = params.b <= 992;
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
bool const use_prepare_varlen = true;
params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;
params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);

params.pagedkv_tma = get_pagedkv_tma(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
Expand All @@ -603,18 +615,35 @@ mha_fwd_get_scheduler_metadata(
// This needs to be set after get_num_splits
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
if (scheduler_needs_semaphore || use_dynamic_split) {
tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template
params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template
if (scheduler_needs_semaphore || use_prepare_varlen) {
int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers
int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;
if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }
if(params.head_swizzle) { num_prepare_batch_vectors += 1; }
int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2);
int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
// printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors);
tile_count_semaphore = torch::empty(
{int(scheduler_needs_semaphore) + tile_count_semaphore_offset},
opts.dtype(torch::kInt32));
// {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}
params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;
params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() + b_rounded : nullptr;
params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr<int>() + b_rounded * 2 : nullptr;
// params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;
params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;
if (scheduler_needs_semaphore) {
if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>() + tile_count_semaphore_offset;
} else {
params.tile_count_semaphore = nullptr;
}
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
}

if (params.num_splits_dynamic_ptr) {
if (use_prepare_varlen) {
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
Expand Down Expand Up @@ -938,11 +967,11 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
}
}

// 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
bool const use_dynamic_split = is_varlen && params.b <= 992;
bool const use_prepare_varlen = is_varlen;
params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;
// Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);

params.pagedkv_tma = get_pagedkv_tma(params);
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
Expand All @@ -955,8 +984,17 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
bool const scheduler_needs_semaphore = params.arch >= 90
? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
: ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
if (scheduler_needs_semaphore || use_dynamic_split) {
int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template
params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template
if (scheduler_needs_semaphore || use_prepare_varlen) {
int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers
int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;
if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }
if(params.head_swizzle) { num_prepare_batch_vectors += 1; }
int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2);
int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset;
// printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size);
params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
if (scheduler_metadata_.has_value()) {
at::Tensor scheduler_metadata = scheduler_metadata_.value();
Expand All @@ -968,15 +1006,22 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
} else {
tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
}
if (scheduler_needs_semaphore && !use_dynamic_split) {
if (scheduler_needs_semaphore && !use_prepare_varlen) {
tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
}
params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
// {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}
params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;
params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() + b_rounded : nullptr;
params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr<int>() + b_rounded * 2 : nullptr;
// params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;
params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;
params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() + tile_count_semaphore_offset : nullptr;
params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later
}

if (q_v_.has_value()) {
TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256.");
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
"q_v is only supported for fp16 and bf16 data type");
TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
Expand Down Expand Up @@ -1134,7 +1179,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
} else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
// need to zero out the semaphore in this case
tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_();
}
} else if (total_q > 0 && num_heads_k > 0) {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
Expand Down
3 changes: 2 additions & 1 deletion hopper/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def _flash_attn_forward(
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0):
sm_margin=0,
):
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
Expand Down
11 changes: 8 additions & 3 deletions hopper/flash_fwd_combine_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class FlashAttnFwdCombine {
int const* const cu_seqlens = nullptr;
int const* const seqused = nullptr;
int const* const num_splits_dynamic_ptr = nullptr;
int const* const varlen_batch_idx_ptr = nullptr;
int* const semaphore_to_reset = nullptr;
};

Expand All @@ -164,6 +165,7 @@ class FlashAttnFwdCombine {
int const* const cu_seqlens = nullptr;
int const* const seqused = nullptr;
int const* const num_splits_dynamic_ptr = nullptr;
int const* const varlen_batch_idx_ptr = nullptr;
int* const semaphore_to_reset = nullptr;
};

Expand All @@ -187,7 +189,9 @@ class FlashAttnFwdCombine {
args.cu_seqlens,
args.seqused,
args.num_splits_dynamic_ptr,
args.semaphore_to_reset
args.varlen_batch_idx_ptr,
args.semaphore_to_reset,

};
}

Expand All @@ -203,8 +207,9 @@ class FlashAttnFwdCombine {
int const thread_idx = threadIdx.x;
int const m_block = blockIdx.x;
int const k_block = blockIdx.y;
int const batch = blockIdx.z;
int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
int const maybe_virtual_batch = blockIdx.z;
int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch;
int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial);

if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
cutlass::arch::wait_on_dependent_grids();
Expand Down
2 changes: 1 addition & 1 deletion hopper/flash_fwd_combine_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
{params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
static_cast<float*>(params.softmax_lse_ptr),
{_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore
};

typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
Expand Down
Loading