diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 58188137777..5921c374f6b 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -2,10 +2,8 @@
  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  ******************************************************************************/
 
-// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
-#include <torch/python.h>
+#include <Python.h>
 #include <torch/nn/functional.h>
-#include <torch/version.h>  // For TORCH_VERSION* macros
 
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -17,44 +15,25 @@
 #include "heuristics.h"
 #include "cuda_check.h"
 
-// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
-// This is so that we can pass in torch.dtype as a parameter to the function.
-#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
-
-#include <torch/csrc/Dtype.h>
-#include <torch/csrc/DynamicTypes.h>
-
-namespace pybind11::detail {
-
-  template <>
-  struct type_caster<at::ScalarType> {
-   public:
-    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-    PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
-    // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
-    // cannot be default-initialized, we provide this constructor to explicitly
-    // initialize that field. The value doesn't matter as it will be overwritten
-    // after a successful call to load.
-    type_caster() : value(at::kFloat) {}
-    bool load(handle src, bool) {
-      PyObject* obj = src.ptr();
-      if (THPDtype_Check(obj)) {
-        value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
-        return true;
-      }
-      return false;
-    }
-    static handle cast(
-        const at::ScalarType& src,
-        return_value_policy /* policy */,
-        handle /* parent */) {
-      return Py_NewRef(torch::getTHPDtype(src));
-    }
-  };
-
-} // namespace pybind11::detail
-#endif
+extern "C" {
+/* Creates a dummy empty _C module that can be imported from Python.
+   The import from Python will load the .so consisting of this file
+   in this extension, so that the TORCH_LIBRARY static initializers
+   below are run. */
+PyObject* PyInit__C(void)
+{
+    static struct PyModuleDef module_def = {
+        PyModuleDef_HEAD_INIT,
+        "_C",   /* name of module */
+        NULL,   /* module documentation, may be NULL */
+        -1,     /* size of per-interpreter state of the module,
+                   or -1 if the module keeps state in global variables. */
+        NULL,   /* methods */
+    };
+    return PyModule_Create(&module_def);
+}
+}
 
 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 
@@ -513,30 +492,30 @@ inline int round_up_headdimv(int head_size) {
 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
 at::Tensor
 mha_fwd_get_scheduler_metadata(
-        int batch_size,
-        int max_seqlen_q,
-        int max_seqlen_k,
-        int num_heads,
-        int num_heads_k,
-        int headdim,
-        int headdim_v,
+        int64_t batch_size,
+        int64_t max_seqlen_q,
+        int64_t max_seqlen_k,
+        int64_t num_heads,
+        int64_t num_heads_k,
+        int64_t headdim,
+        int64_t headdim_v,
         at::ScalarType qkv_dtype,
-        const at::Tensor &seqused_k,  // b
-        std::optional<const at::Tensor> &cu_seqlens_q_,  // b+1
-        std::optional<const at::Tensor> &cu_seqlens_k_,  // b+1
-        std::optional<const at::Tensor> &cu_seqlens_k_new_,  // b+1
-        std::optional<const at::Tensor> &seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-        std::optional<const at::Tensor> &leftpad_k_,  // b
-        std::optional<int> page_size,
-        int max_seqlen_k_new,  // 0 means we're not appending new KV
+        at::Tensor seqused_k,  // b
+        std::optional<at::Tensor> cu_seqlens_q_,  // b+1
+        std::optional<at::Tensor> cu_seqlens_k_,  // b+1
+        std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1
+        std::optional<at::Tensor> seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+        std::optional<at::Tensor> leftpad_k_,  // b
+        std::optional<int64_t> page_size,
+        int64_t max_seqlen_k_new,  // 0 means we're not appending new KV
         bool is_causal,
-        int window_size_left,
-        int window_size_right,
-        int attention_chunk,
+        int64_t window_size_left,
+        int64_t window_size_right,
+        int64_t attention_chunk,
         bool has_softcap,
-        int num_splits,
+        int64_t num_splits,
         std::optional<bool> pack_gqa_,
-        int const sm_margin
+        int64_t sm_margin
         ) {
 
     TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
@@ -645,42 +624,42 @@
 // h: num_heads
 // h_k: num_heads_k
 // d: head_size
-std::vector<at::Tensor>
-mha_fwd(at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-        const at::Tensor &k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
-        const at::Tensor &v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
-        std::optional<const at::Tensor> &k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-        std::optional<const at::Tensor> &v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-        std::optional<const at::Tensor> &q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-        std::optional<at::Tensor> &out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-        std::optional<const at::Tensor> &cu_seqlens_q_,  // b+1
-        std::optional<const at::Tensor> &cu_seqlens_k_,  // b+1
-        std::optional<const at::Tensor> &cu_seqlens_k_new_,  // b+1
-        std::optional<const at::Tensor> &seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-        std::optional<const at::Tensor> &seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-        std::optional<int> max_seqlen_q_,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+mha_fwd(at::Tensor q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+        at::Tensor k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
+        at::Tensor v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
+        std::optional<at::Tensor> k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+        std::optional<at::Tensor> v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+        std::optional<at::Tensor> q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+        std::optional<at::Tensor> out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+        std::optional<at::Tensor> cu_seqlens_q_,  // b+1
+        std::optional<at::Tensor> cu_seqlens_k_,  // b+1
+        std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1
+        std::optional<at::Tensor> seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+        std::optional<at::Tensor> seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
+        std::optional<int64_t> max_seqlen_q_,
         // TODO: check if we need max_seqlen_k
-        std::optional<int> max_seqlen_k_,
-        std::optional<const at::Tensor> &page_table_,  // (b_k, max_num_pages_per_seq)
-        std::optional<const at::Tensor> &kv_batch_idx_,  // b. indices to index into the KV cache
-        std::optional<const at::Tensor> &leftpad_k_,  // b
-        std::optional<const at::Tensor> &rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
-        std::optional<const at::Tensor> &rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
-        std::optional<const at::Tensor> &seqlens_rotary_,  // b
-        std::optional<at::Tensor> &q_descale_,  // (b, h_k), not (b, h)
-        std::optional<at::Tensor> &k_descale_,  // (b, h_k)
-        std::optional<at::Tensor> &v_descale_,  // (b, h_k)
-        float const softmax_scale,
+        std::optional<int64_t> max_seqlen_k_,
+        std::optional<at::Tensor> page_table_,  // (b_k, max_num_pages_per_seq)
+        std::optional<at::Tensor> kv_batch_idx_,  // b. indices to index into the KV cache
+        std::optional<at::Tensor> leftpad_k_,  // b
+        std::optional<at::Tensor> rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
+        std::optional<at::Tensor> rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
+        std::optional<at::Tensor> seqlens_rotary_,  // b
+        std::optional<at::Tensor> q_descale_,  // (b, h_k), not (b, h)
+        std::optional<at::Tensor> k_descale_,  // (b, h_k)
+        std::optional<at::Tensor> v_descale_,  // (b, h_k)
+        double softmax_scale,
         bool is_causal,
-        int window_size_left,
-        int window_size_right,
-        int attention_chunk,
-        float const softcap,
-        bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-        std::optional<at::Tensor> &scheduler_metadata_,  // (b + 1)
-        int num_splits,
+        int64_t window_size_left,
+        int64_t window_size_right,
+        int64_t attention_chunk,
+        double softcap,
+        bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+        std::optional<at::Tensor> scheduler_metadata_,  // (b + 1)
+        int64_t num_splits,
         std::optional<bool> pack_gqa_,
-        int const sm_margin
+        int64_t sm_margin
         ) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1211,29 +1190,30 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 // h: num_heads
 // h_k: num_heads_k
 // d: head_size
-std::vector<at::Tensor> mha_bwd(
-    const at::Tensor &dout,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    const at::Tensor &q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const at::Tensor &k,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    const at::Tensor &v,  // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
-    const at::Tensor &out,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    const at::Tensor &softmax_lse,  // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
-    std::optional<at::Tensor> &dq_,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    std::optional<at::Tensor> &dk_,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
-    std::optional<at::Tensor> &dv_,  // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
-    std::optional<const at::Tensor> &cu_seqlens_q_,  // b+1
-    std::optional<const at::Tensor> &cu_seqlens_k_,  // b+1
-    std::optional<const at::Tensor> &seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<const at::Tensor> &seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int> max_seqlen_q_,
-    std::optional<int> max_seqlen_k_,
-    float const softmax_scale,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
+    at::Tensor dout,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    at::Tensor q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    at::Tensor k,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
+    at::Tensor v,  // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
+    at::Tensor out,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    at::Tensor softmax_lse,  // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
+    std::optional<at::Tensor> dq_,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    std::optional<at::Tensor> dk_,  // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
+    std::optional<at::Tensor> dv_,  // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
+    std::optional<at::Tensor> cu_seqlens_q_,  // b+1
+    std::optional<at::Tensor> cu_seqlens_k_,  // b+1
+    std::optional<at::Tensor> seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+    std::optional<at::Tensor> seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
+    std::optional<int64_t> max_seqlen_q_,
+    std::optional<int64_t> max_seqlen_k_,
+    double softmax_scale,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    float const softcap,
-    bool const deterministic,
-    int const sm_margin) {
+    int64_t window_size_left,
+    int64_t window_size_right,
+    double softcap,
+    bool deterministic,
+    int64_t sm_margin
+) {
 
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1507,9 +1487,9 @@ std::vector<at::Tensor> mha_bwd(
     return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
 }
 
-std::vector<at::Tensor>
-mha_combine(const at::Tensor &out_partial,  // num_splits x batch_size x seqlen x num_heads x head_size
-            const at::Tensor &lse_partial,  // num_splits x batch_size x seqlen x num_heads
+std::tuple<at::Tensor, at::Tensor>
+mha_combine(at::Tensor out_partial,  // num_splits x batch_size x seqlen x num_heads x head_size
+            at::Tensor lse_partial,  // num_splits x batch_size x seqlen x num_heads
             std::optional<at::Tensor> out_,  // batch_size x seqlen x num_heads x head_size
             std::optional<at::ScalarType> out_dtype_
             ) {
@@ -1610,10 +1590,100 @@ mha_combine(const at::Tensor &out_partial,  // num_splits x batch_size x
     return {out, softmax_lse};
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.doc() = "FlashAttention";
-    m.def("fwd", &mha_fwd, "Forward pass");
-    m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
-    m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
+TORCH_LIBRARY(flash_attn_3, m) {
+    m.def("fwd("
+          "Tensor q,"
+          "Tensor k,"
+          "Tensor v,"
+          "Tensor(k_new!)? k_new,"
+          "Tensor(v_new!)? v_new,"
+          "Tensor? q_v,"
+          "Tensor(out!)? out,"
+          "Tensor? cu_seqlens_q,"
+          "Tensor? cu_seqlens_k,"
+          "Tensor? cu_seqlens_k_new,"
+          "Tensor? seqused_q,"
+          "Tensor? seqused_k,"
+          "int? max_seqlen_q,"
+          "int? max_seqlen_k,"
+          "Tensor? page_table,"
+          "Tensor? kv_batch_idx,"
+          "Tensor? leftpad_k,"
+          "Tensor? rotary_cos,"
+          "Tensor? rotary_sin,"
+          "Tensor? seqlens_rotary,"
+          "Tensor? q_descale,"
+          "Tensor? k_descale,"
+          "Tensor? v_descale,"
+          "float softmax_scale,"
+          "bool is_causal,"
+          "int window_size_left = -1,"
+          "int window_size_right = -1,"
+          "int attention_chunk = 0,"
+          "float softcap = 0.0,"
+          "bool is_rotary_interleaved = False,"
+          "Tensor? scheduler_metadata = None,"
+          "int num_splits = 0,"
+          "bool? pack_gqa = None,"
+          "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
+    m.def("bwd("
+          "Tensor dout,"
+          "Tensor q,"
+          "Tensor k,"
+          "Tensor v,"
+          "Tensor out,"
+          "Tensor softmax_lse,"
+          "Tensor(dq!)? dq,"
+          "Tensor(dk!)? dk,"
+          "Tensor(dv!)? dv,"
+          "Tensor? cu_seqlens_q,"
+          "Tensor? cu_seqlens_k,"
+          "Tensor? seqused_q,"
+          "Tensor? seqused_k,"
+          "int? max_seqlen_q,"
+          "int? max_seqlen_k,"
+          "float softmax_scale,"
+          "bool is_causal,"
+          "int window_size_left = -1,"
+          "int window_size_right = -1,"
+          "float softcap = 0.0,"
+          "bool deterministic = False,"
+          "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
+    m.def("fwd_combine("
+          "Tensor out_partial,"
+          "Tensor lse_partial,"
+          "Tensor(out!)? out = None,"
+          "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
+    m.def("get_scheduler_metadata("
+          "int batch_size,"
+          "int max_seqlen_q,"
+          "int max_seqlen_k,"
+          "int num_heads,"
+          "int num_heads_k,"
+          "int headdim,"
+          "int headdim_v,"
+          "ScalarType qkv_dtype,"
+          "Tensor seqused_k,"
+          "Tensor? cu_seqlens_q,"
+          "Tensor? cu_seqlens_k,"
+          "Tensor? cu_seqlens_k_new,"
+          "Tensor? seqused_q,"
+          "Tensor? leftpad_k,"
+          "int? page_size,"
+          "int max_seqlen_k_new,"
+          "bool is_causal,"
+          "int window_size_left,"
+          "int window_size_right,"
+          "int attention_chunk,"
+          "bool has_softcap = False,"
+          "int num_splits = 0,"
+          "bool? pack_gqa = None,"
+          "int sm_margin = 0) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
+    m.impl("fwd", &mha_fwd);
+    m.impl("bwd", &mha_bwd);
+    m.impl("fwd_combine", &mha_combine);
+    m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
 }
diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index 06782fa409b..cfb8881b4b2 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -7,10 +7,11 @@
 
 # isort: off
 # We need to import the CUDA kernels after importing torch
-import flash_attn_3_cuda
+import flash_attn_3._C  # Registers operators with PyTorch
 # isort: on
 
+flash_attn_3_cuda = torch.ops.flash_attn_3
 
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
diff --git a/hopper/setup.py b/hopper/setup.py
index 7ed8abce15f..c15c438f56c 100644
--- a/hopper/setup.py
+++ b/hopper/setup.py
@@ -539,13 +539,14 @@ def nvcc_threads_args():
 
     ext_modules.append(
         CUDAExtension(
-            name="flash_attn_3_cuda",
+            name=f"{PACKAGE_NAME}._C",
             sources=sources,
             extra_compile_args={
-                "cxx": ["-O3", "-std=c++17"] + feature_args,
+                "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + feature_args,
                 "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args,
             },
             include_dirs=include_dirs,
+            py_limited_api=True,
         )
     )
 
@@ -654,4 +655,5 @@ def run(self):
         "packaging",
         "ninja",
     ],
+    options={"bdist_wheel": {"py_limited_api": "cp39"}},
 )
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index 80d4dc0c15c..7e2e6fd87a8 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -134,7 +134,7 @@ def test_flash_attn_output(
     else:
         qv_ref = None
     # Put window_size after QKV randn so that window_size changes from test to test
-    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
     # window_size = (-1, -1) if not local else (16, 0)
     if dtype == torch.float8_e4m3fn:
         q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
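
Usage sketch: with the pybind11 module replaced by TORCH_LIBRARY registration, importing flash_attn_3._C runs PyInit__C plus the static initializers above, after which the kernels live under torch.ops.flash_attn_3 (the flash_attn_3_cuda alias in flash_attn_interface.py points there). The snippet below is a hedged illustration rather than code from this patch: the tensor shapes are arbitrary, a supported (Hopper) GPU and an installed wheel are assumed, and the names chosen for the two trailing outputs (assumed to be split-KV accumulation buffers) are illustrative.

    import torch
    import flash_attn_3._C  # noqa: F401 -- loads the .so; its static initializers register the ops

    # Plain dense causal attention: the 20 optional arguments between v and
    # softmax_scale in the "fwd" schema (k_new ... v_descale) are all None here.
    q, k, v = (torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3))
    out, softmax_lse, out_accum, lse_accum = torch.ops.flash_attn_3.fwd(
        q, k, v, *([None] * 20),
        softmax_scale=q.shape[-1] ** -0.5,
        is_causal=True,
    )
    print(out.shape, softmax_lse.shape)  # torch.Size([2, 512, 8, 128]) torch.Size([2, 8, 512])

Because the extension is compiled with -DPy_LIMITED_API=0x03090000 and built with py_limited_api=True, its only Python C-API surface is the stable-ABI PyInit__C entry point, which is what lets a single cp39-abi3 wheel cover CPython 3.9 and newer.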