diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 5b3d124627a..cb8cda73934 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -3,9 +3,18 @@ ******************************************************************************/ #include -#include -#include -#include +// #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include @@ -35,9 +44,14 @@ PyObject* PyInit__C(void) } } -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_DEVICE(x) STABLE_TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) STABLE_TORCH_CHECK(x.sizes() == torch::standalone::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) STABLE_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +using torch::stable::Tensor; +using torch::stable::cuda::CUDAGuard; +using torch::stable::cuda::getCurrentDeviceProperties; +using torch::standalone::ScalarType; void set_params_fprop(Flash_fwd_params ¶ms, // sizes @@ -51,10 +65,10 @@ void set_params_fprop(Flash_fwd_params ¶ms, const size_t d, const size_t d_rounded, // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - at::Tensor out, + const Tensor q, + const Tensor k, + const Tensor v, + Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_q, @@ -71,8 +85,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Reset the parameters params = {}; - params.is_bf16 = q.dtype() == torch::kBFloat16; - params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + params.is_bf16 = q.dtype() == torch::standalone::kBFloat16; + params.is_e4m3 = q.dtype() == torch::standalone::kFloat8_e4m3fn; // Set the pointers and strides. params.q_ptr = q.data_ptr(); @@ -130,9 +144,9 @@ void set_params_fprop(Flash_fwd_params ¶ms, // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; - TORCH_CHECK(p_dropout < 1.f); + STABLE_TORCH_CHECK(p_dropout < 1.f); #ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + STABLE_TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. @@ -151,11 +165,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.window_size_right = window_size_right; params.attention_chunk = attention_chunk; - params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + params.arch = getCurrentDeviceProperties()->major * 10 + getCurrentDeviceProperties()->minor; + params.num_sm = getCurrentDeviceProperties()->multiProcessorCount - sm_margin; #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + STABLE_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif } @@ -171,14 +185,14 @@ void set_params_dgrad(Flash_bwd_params ¶ms, const size_t d, const size_t d_rounded, // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - const at::Tensor out, - const at::Tensor dout, - at::Tensor dq, - at::Tensor dk, - at::Tensor dv, + const Tensor q, + const Tensor k, + const Tensor v, + const Tensor out, + const Tensor dout, + Tensor dq, + Tensor dk, + Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_q, @@ -248,7 +262,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // HEADDIM_SWITCH(params.d, [&] { // run_mha_fwd_(params, stream); // }); - TORCH_CHECK(params.num_splits >= 1); + STABLE_TORCH_CHECK(params.num_splits >= 1); ARCH_SWITCH(params.arch, Arch, [&] { SPLIT_SWITCH(params.num_splits > 1, Split, [&] { PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { @@ -319,7 +333,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); + STABLE_TORCH_CHECK(false, "This flash attention build does not support FP16."); #endif } } else { @@ -346,7 +360,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #else - TORCH_CHECK(false, "This flash attention build does not support FP8."); + STABLE_TORCH_CHECK(false, "This flash attention build does not support FP8."); #endif } }); @@ -380,7 +394,7 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool ena } } #else - TORCH_CHECK(false, "This flash attention build does not support combine kernels."); + STABLE_TORCH_CHECK(false, "This flash attention build does not support combine kernels."); #endif } @@ -490,7 +504,7 @@ inline int round_up_headdimv(int head_size) { } // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available -at::Tensor +Tensor mha_fwd_get_scheduler_metadata( int64_t batch_size, int64_t max_seqlen_q, @@ -499,13 +513,13 @@ mha_fwd_get_scheduler_metadata( int64_t num_heads_k, int64_t headdim, int64_t headdim_v, - at::ScalarType qkv_dtype, - at::Tensor seqused_k, // b - std::optional cu_seqlens_q_, // b+1 - std::optional cu_seqlens_k_, // b+1 - std::optional cu_seqlens_k_new_, // b+1 - std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional leftpad_k_, // b + ScalarType qkv_dtype, + Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b std::optional page_size, int64_t max_seqlen_k_new, // 0 means we're not appending new KV bool is_causal, @@ -518,14 +532,14 @@ mha_fwd_get_scheduler_metadata( int64_t sm_margin ) { - TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, + STABLE_TORCH_CHECK(qkv_dtype == ScalarType::Half || qkv_dtype == ScalarType::BFloat16 || qkv_dtype == ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + STABLE_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); // Reset the parameters Flash_fwd_params params{}; - params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; - params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; + params.is_bf16 = qkv_dtype == ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == ScalarType::Float8_e4m3fn; params.b = batch_size; params.seqlen_q = max_seqlen_q; params.seqlen_k = max_seqlen_k; @@ -568,8 +582,8 @@ mha_fwd_get_scheduler_metadata( params.window_size_left = window_size_left; params.window_size_right = window_size_right; params.attention_chunk = attention_chunk; - params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + params.arch = getCurrentDeviceProperties()->major * 10 + getCurrentDeviceProperties()->minor; + params.num_sm = getCurrentDeviceProperties()->multiProcessorCount - sm_margin; params.softcap = has_softcap ? 1.0f : 0.0f; params.page_size = page_size.has_value() ? page_size.value() : 1; @@ -587,14 +601,14 @@ mha_fwd_get_scheduler_metadata( // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + CUDAGuard device_guard{(char)seqused_k.get_device()}; auto opts = seqused_k.options(); // This needs to be set after get_num_splits - at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + tile_count_semaphore = torch::stable::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::standalone::kInt32)); if (scheduler_needs_semaphore) { if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -609,7 +623,7 @@ mha_fwd_get_scheduler_metadata( auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = torch::stable::cuda::getCurrentCUDAStream().stream(); prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -624,31 +638,31 @@ mha_fwd_get_scheduler_metadata( // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple -mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. - std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new - std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - std::optional cu_seqlens_q_, // b+1 - std::optional cu_seqlens_k_, // b+1 - std::optional cu_seqlens_k_new_, // b+1 - std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. +std::tuple +mha_fwd(Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k std::optional max_seqlen_k_, - std::optional page_table_, // (b_k, max_num_pages_per_seq) - std::optional kv_batch_idx_, // b. indices to index into the KV cache - std::optional leftpad_k_, // b - std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional seqlens_rotary_, // b - std::optional q_descale_, // (b, h_k), not (b, h) - std::optional k_descale_, // (b, h_k) - std::optional v_descale_, // (b, h_k) + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) double softmax_scale, bool is_causal, int64_t window_size_left, @@ -656,58 +670,58 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql int64_t attention_chunk, double softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional scheduler_metadata_, // (b + 1) + std::optional scheduler_metadata_, // (b + 1) int64_t num_splits, std::optional pack_gqa_, int64_t sm_margin ) { - auto dprops = at::cuda::getCurrentDeviceProperties(); + auto dprops = getCurrentDeviceProperties(); bool is_sm8x = dprops->major >= 8; - TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + STABLE_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); auto q_type = q.scalar_type(); - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn, + STABLE_TORCH_CHECK(q_type == ScalarType::Half || q_type == ScalarType::BFloat16 || q_type == ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); if (dprops->major < 9) { - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + STABLE_TORCH_CHECK(q_type == ScalarType::Half || q_type == ScalarType::BFloat16, "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); } - TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + STABLE_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STABLE_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - at::Tensor page_table; + Tensor page_table; const bool paged_KV = page_table_.has_value(); if (paged_KV) { page_table = page_table_.value(); CHECK_DEVICE(page_table); - TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); - TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + STABLE_TORCH_CHECK(page_table.dtype() == torch::standalone::kInt32, "page_table must have dtype torch.int32"); + STABLE_TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); } - at::Tensor cu_seqlens_q; + Tensor cu_seqlens_q; bool const is_varlen_q = cu_seqlens_q_.has_value(); if (is_varlen_q) { cu_seqlens_q = cu_seqlens_q_.value(); CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + STABLE_TORCH_CHECK(cu_seqlens_q.dtype() == torch::standalone::kInt32, "cu_seqlens_q must have dtype torch.int32"); + STABLE_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); } - at::Tensor cu_seqlens_k; + Tensor cu_seqlens_k; bool const is_varlen_k = cu_seqlens_k_.has_value(); if (is_varlen_k) { cu_seqlens_k = cu_seqlens_k_.value(); CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); - TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); - TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + STABLE_TORCH_CHECK(cu_seqlens_k.dtype() == torch::standalone::kInt32, "cu_seqlens_k must have dtype torch.int32"); + STABLE_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + STABLE_TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); + STABLE_TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); } auto const sizes = q.sizes(); @@ -725,19 +739,19 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); if (!kv_batch_idx_.has_value()) { - TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + STABLE_TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); } int const max_headdim = get_max_headdim(); - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + STABLE_TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STABLE_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (head_size_v != head_size) { - TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + STABLE_TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || (head_size <= 64 && head_size_v <= 512), "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " "or (Q/K <= 64 and V <= 512)."); - TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + STABLE_TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); if (head_size_v > 256) { - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + STABLE_TORCH_CHECK(q_type == ScalarType::Half || q_type == ScalarType::BFloat16, "HeaddimV > 256 requires fp16 and bf16 data type"); } } @@ -778,20 +792,20 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql if (seqused_q_.has_value()){ auto seqused_q = seqused_q_.value(); - TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + STABLE_TORCH_CHECK(seqused_q.dtype() == torch::standalone::kInt32, "seqused_q must have dtype int32"); CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); CHECK_SHAPE(seqused_q, batch_size); } if (seqused_k_.has_value()) { auto seqused_k = seqused_k_.value(); - TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + STABLE_TORCH_CHECK(seqused_k.dtype() == torch::standalone::kInt32, "seqused_k must have dtype int32"); CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); CHECK_SHAPE(seqused_k, batch_size); } if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + STABLE_TORCH_CHECK(leftpad_k.dtype() == torch::standalone::kInt32, "leftpad_k must have dtype int32"); CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); CHECK_SHAPE(leftpad_k, batch_size); } @@ -799,21 +813,21 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql // This is what we will template on bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + STABLE_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); #endif - int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; - TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); - TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + int const alignment = q_type == torch::standalone::kFloat8_e4m3fn ? 16 : 8; + STABLE_TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + STABLE_TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); - auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; - at::Tensor out; + auto out_type = q_type == ScalarType::Float8_e4m3fn ? ScalarType::BFloat16 : q_type; + Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + STABLE_TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); } else { @@ -821,8 +835,8 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } } else { out = !is_varlen_q - ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) - : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); + ? torch::stable::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::stable::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -833,13 +847,13 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + CUDAGuard device_guard{(char)q.get_device()}; - at::Tensor softmax_lse; + Tensor softmax_lse; if (!is_varlen_q) { - softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + softmax_lse = torch::stable::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::standalone::kFloat)); } else { - softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + softmax_lse = torch::stable::empty({num_heads, total_q}, opts.dtype(torch::standalone::kFloat)); } Flash_fwd_params params; @@ -878,24 +892,24 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.num_pages = num_pages; if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma - at::Tensor k_new, v_new; - TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); - TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); - TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); - at::Tensor cu_seqlens_k_new; + Tensor k_new, v_new; + STABLE_TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + STABLE_TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + STABLE_TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + Tensor cu_seqlens_k_new; bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); if (is_varlen_k_new) { cu_seqlens_k_new = cu_seqlens_k_new_.value(); CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); - TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); + STABLE_TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::standalone::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); } k_new = k_new_.value(); v_new = v_new_.value(); - TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); - TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); + STABLE_TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); + STABLE_TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); - TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); - TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); @@ -936,7 +950,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); // This needs to be set after get_num_splits - at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic // We don't use the persistent scheduler if Split and not Varlen bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) @@ -945,14 +959,14 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { - at::Tensor scheduler_metadata = scheduler_metadata_.value(); + Tensor scheduler_metadata = scheduler_metadata_.value(); CHECK_DEVICE(scheduler_metadata); CHECK_SHAPE(scheduler_metadata, metadata_size); CHECK_CONTIGUOUS(scheduler_metadata); - TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); + STABLE_TORCH_CHECK(scheduler_metadata.dtype() == torch::standalone::kInt32, "scheduler_metadata must have dtype int32"); tile_count_semaphore = scheduler_metadata; } else { - tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); + tile_count_semaphore = torch::stable::empty({metadata_size}, opts.dtype(torch::standalone::kInt32)); } if (scheduler_needs_semaphore && !use_dynamic_split) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing @@ -962,14 +976,14 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } if (q_v_.has_value()) { - TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + STABLE_TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + STABLE_TORCH_CHECK(q_type == ScalarType::Half || q_type == ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); - at::Tensor q_v = q_v_.value(); - TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + STABLE_TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + Tensor q_v = q_v_.value(); + STABLE_TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); CHECK_DEVICE(q_v); - TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); if (!is_varlen_q) { CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); } else { @@ -985,31 +999,31 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } if (rotary_cos_.has_value()) { - TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + STABLE_TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + STABLE_TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + STABLE_TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); const int seqlen_ro = rotary_cos.size(0); if (paged_KV) { - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + STABLE_TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); } CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + STABLE_TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + STABLE_TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); auto rotary_sin = rotary_sin_.value(); CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + STABLE_TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; if (seqlens_rotary_.has_value()) { - at::Tensor seqlens_rotary = seqlens_rotary_.value(); + Tensor seqlens_rotary = seqlens_rotary_.value(); CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); - TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + STABLE_TORCH_CHECK(seqlens_rotary.dtype() == torch::standalone::kInt32, "seqlens_rotary must have dtype torch.int32"); CHECK_SHAPE(seqlens_rotary, batch_size); params.seqlens_rotary = seqlens_rotary.data_ptr(); } @@ -1020,22 +1034,22 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql if (kv_batch_idx_.has_value()) { auto kv_batch_idx = kv_batch_idx_.value(); CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); - TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + STABLE_TORCH_CHECK(kv_batch_idx.scalar_type() == torch::standalone::kInt32, "kv_batch_idx must have dtype int32"); params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); } - at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; + Tensor out_accum, softmax_lse_accum; + auto outaccum_type = ScalarType::Float; if (params.num_splits > 1) { - TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + STABLE_TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); - softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + out_accum = torch::stable::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::stable::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(torch::standalone::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); - softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); + out_accum = torch::stable::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::stable::empty({params.num_splits, num_heads, total_q}, opts.dtype(torch::standalone::kFloat)); } params.is_fp32 = false; params.oaccum_ptr = out_accum.data_ptr(); @@ -1047,7 +1061,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } - if (q_type == at::ScalarType::Float8_e4m3fn) { + if (q_type == ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); CHECK_DEVICE(q_descale); @@ -1081,29 +1095,29 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + STABLE_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + STABLE_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); #endif #ifdef FLASHATTENTION_DISABLE_SPLIT - TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); + STABLE_TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA - TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + STABLE_TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV - TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); + STABLE_TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV - TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); + STABLE_TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); #endif if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = torch::stable::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); if (params.num_splits > 1) { - if (out_type == at::ScalarType::BFloat16) { + if (out_type == ScalarType::BFloat16) { // Since we want output in BF16. Otherwise fwd_combine will output to FP16 params.is_bf16 = true; } @@ -1120,7 +1134,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case - tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); + tile_count_semaphore.index({torch::stable::indexing::Slice(0, 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. @@ -1159,7 +1173,7 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } #endif #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); + STABLE_TORCH_CHECK(false, "This flash attention build does not support FP16."); #endif } else { #ifndef FLASHATTENTION_DISABLE_HDIM64 @@ -1190,20 +1204,20 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple mha_bwd( - at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k - at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q - std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k - std::optional cu_seqlens_q_, // b+1 - std::optional cu_seqlens_k_, // b+1 - std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. +std::tuple mha_bwd( + Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. std::optional max_seqlen_q_, std::optional max_seqlen_k_, double softmax_scale, @@ -1216,50 +1230,50 @@ std::tuplemajor >= 8; - TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + STABLE_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); auto q_type = q.dtype(); - TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16, + STABLE_TORCH_CHECK(q_type == torch::standalone::kFloat16 || q_type == torch::standalone::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype"); + STABLE_TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype"); + STABLE_TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype"); + STABLE_TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype"); + STABLE_TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - at::Tensor cu_seqlens_q; + Tensor cu_seqlens_q; bool const is_varlen_q = cu_seqlens_q_.has_value(); if (is_varlen_q) { cu_seqlens_q = cu_seqlens_q_.value(); CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + STABLE_TORCH_CHECK(cu_seqlens_q.dtype() == torch::standalone::kInt32, "cu_seqlens_q must have dtype torch.int32"); + STABLE_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); } - at::Tensor cu_seqlens_k; + Tensor cu_seqlens_k; bool const is_varlen_k = cu_seqlens_k_.has_value(); if (is_varlen_k) { cu_seqlens_k = cu_seqlens_k_.value(); CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + STABLE_TORCH_CHECK(cu_seqlens_k.dtype() == torch::standalone::kInt32, "cu_seqlens_k must have dtype torch.int32"); + STABLE_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); } // This is what we will template on bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + STABLE_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); #endif auto const sizes = q.sizes(); @@ -1272,11 +1286,11 @@ std::tuple= seqlen_k - 1) { window_size_left = -1; } @@ -1286,7 +1300,7 @@ std::tuplemajor * 10 + at::cuda::getCurrentDeviceProperties()->minor; + int const arch = getCurrentDeviceProperties()->major * 10 + getCurrentDeviceProperties()->minor; int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); int const head_size_v_rounded = head_size_rounded; // Very important that these match the kernel configs @@ -1336,86 +1350,86 @@ std::tuple(); // Will be zero'ed out in the backward preprocess kernel - at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); + Tensor dq_semaphore = torch::stable::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::standalone::kInt32)); params.dq_semaphore = dq_semaphore.data_ptr(); if (num_heads_k != num_heads && params.deterministic) { // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + Tensor dk_semaphore = torch::stable::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::standalone::kInt32)); + Tensor dv_semaphore = torch::stable::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::standalone::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + STABLE_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + STABLE_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); #endif if (total_q > 0 && total_k > 0 && num_heads_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = torch::stable::cuda::getCurrentCUDAStream().stream(); run_mha_bwd(params, stream); } else if (total_k > 0 && num_heads_k > 0) { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. @@ -1487,25 +1501,25 @@ std::tuple -mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size - at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads - std::optional out_, // batch_size x seqlen x num_heads x head_size - std::optional out_dtype_ +std::tuple +mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads + std::optional out_, // batch_size x seqlen x num_heads x head_size + std::optional out_dtype_ ) { - auto dprops = at::cuda::getCurrentDeviceProperties(); + auto dprops = getCurrentDeviceProperties(); bool is_sm8x = dprops->major >= 8; - TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); + STABLE_TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); auto out_partial_type = out_partial.scalar_type(); - TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type"); - TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type"); + STABLE_TORCH_CHECK(out_partial_type == ScalarType::Float, "Attention combine function only support fp32 data type"); + STABLE_TORCH_CHECK(lse_partial.scalar_type() == ScalarType::Float, "Attention combine function only support fp32 data type"); CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); - TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); + STABLE_TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); const auto sizes = out_partial.sizes(); @@ -1514,15 +1528,15 @@ mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); + STABLE_TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); int const alignment = 4; - at::Tensor out_partial_padded; - auto pad = [](at::Tensor x, int alignment) { - return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment})); + Tensor out_partial_padded; + auto pad = [](Tensor x, int alignment) { + return x.size(-1) % alignment == 0 ? x : torch::stable::nn::functional::pad(x, torch::standalone::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment})); }; out_partial_padded = pad(out_partial, alignment); @@ -1530,31 +1544,31 @@ mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen const int head_size = round_multiple(head_size_og, alignment); auto opts = out_partial.options(); - at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); - TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); - at::Tensor out; + ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); + STABLE_TORCH_CHECK(out_type == ScalarType::Float || out_type == ScalarType::BFloat16 || out_type == ScalarType::Half, "Output type must be FP32, FP16 or BF16"); + Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.scalar_type() == out_type); + STABLE_TORCH_CHECK(out.scalar_type() == out_type); CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + STABLE_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); if (head_size_og % alignment != 0) { - out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); + out = torch::stable::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); } } else { - out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); + out = torch::stable::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type)); } // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()}; + CUDAGuard device_guard{(char)out_partial.get_device()}; - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2); + auto softmax_lse = torch::stable::empty({batch_size, num_heads, seqlen}, opts.dtype(torch::standalone::kFloat)).transpose(1, 2); Flash_fwd_params params {}; // Need to reset the params to set everything to zero - params.is_fp32 = out_type == at::ScalarType::Float; - params.is_bf16 = out_type == at::ScalarType::BFloat16; + params.is_fp32 = out_type == ScalarType::Float; + params.is_bf16 = out_type == ScalarType::BFloat16; params.oaccum_ptr = out_partial_padded.data_ptr(); params.softmax_lseaccum_ptr = lse_partial.data_ptr(); params.o_ptr = out.data_ptr(); @@ -1574,23 +1588,23 @@ mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen params.o_row_stride = out.stride(1); params.o_head_stride = out.stride(2); params.o_batch_stride = out.stride(0); - params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.arch = getCurrentDeviceProperties()->major * 10 + getCurrentDeviceProperties()->minor; if (seqlen > 0 && batch_size > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = torch::stable::cuda::getCurrentCUDAStream().stream(); run_mha_fwd_combine(params, stream, false /*enable_pdl*/); } - at::Tensor out_padded = out; + Tensor out_padded = out; if (head_size_og % alignment != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + out = out.index({"...", torch::stable::indexing::Slice(torch::stable::indexing::None, head_size_og)}); // if (out_.has_value()) { out_.value().copy_(out); } } return {out, softmax_lse}; } -TORCH_LIBRARY(flash_attn_3, m) { +STABLE_TORCH_LIBRARY(flash_attn_3, m) { m.def("fwd(" "Tensor q," "Tensor k," @@ -1681,7 +1695,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "int sm_margin = 0) -> Tensor"); } -TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { +STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { m.impl("fwd", &mha_fwd); m.impl("bwd", &mha_bwd); m.impl("fwd_combine", &mha_combine);