318 changes: 194 additions & 124 deletions hopper/flash_api.cpp
@@ -2,10 +2,8 @@
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <Python.h>
#include <torch/nn/functional.h>
#include <torch/version.h> // For TORCH_VERSION* macros
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

@@ -17,44 +15,25 @@
#include "heuristics.h"
#include "cuda_check.h"

// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
// This is so that we can pass in torch.dtype as a parameter to the function.
#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
Contributor Author:
Note: I had to remove this because we're no longer using pybind11. This will probably break with PyTorch < 2.4. Is this fine, or do we want a backup solution for older PyTorch versions?

Member:
that's fine, we won't support pytorch < 2.4


#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace pybind11::detail {

template <>
struct type_caster<at::ScalarType> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
// PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
// cannot be default-initialized, we provide this constructor to explicitly
// initialize that field. The value doesn't matter as it will be overwritten
// after a successful call to load.
type_caster() : value(at::kFloat) {}
bool load(handle src, bool) {
PyObject* obj = src.ptr();
if (THPDtype_Check(obj)) {
value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
return true;
}
return false;
}
static handle cast(
const at::ScalarType& src,
return_value_policy /* policy */,
handle /* parent */) {
return Py_NewRef(torch::getTHPDtype(src));
}
};

} // namespace pybind11::detail

#endif
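A minimal sketch of why the caster above is no longer needed: with the TORCH_LIBRARY schemas declared further down, a plain Python `torch.dtype` is accepted wherever the schema says `ScalarType` (e.g. `qkv_dtype`, `out_dtype`) and arrives in C++ as `at::ScalarType`. The sketch assumes the extension has been built and is importable as `flash_attn_3._C` (the stub defined just below) and that a supported CUDA device is present; the argument values are illustrative only.

```python
import torch
import flash_attn_3._C  # noqa: F401 -- imported only so the TORCH_LIBRARY registration runs

# Per-batch used KV lengths; the op expects an int32 CUDA tensor of shape (batch,).
seqused_k = torch.full((1,), 4096, dtype=torch.int32, device="cuda")

scheduler_metadata = torch.ops.flash_attn_3.get_scheduler_metadata(
    1, 1, 4096,        # batch_size, max_seqlen_q, max_seqlen_k
    16, 2, 128, 128,   # num_heads, num_heads_k, headdim, headdim_v
    torch.bfloat16,    # qkv_dtype: passed as a torch.dtype, no custom pybind caster involved
    seqused_k,
    None, None, None,  # cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new
    None, None,        # seqused_q, leftpad_k
    None, 0,           # page_size, max_seqlen_k_new
    True, -1, -1, 0,   # is_causal, window_size_left, window_size_right, attention_chunk
)                      # has_softcap, num_splits, pack_gqa, sm_margin fall back to the schema defaults
```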
extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
Importing it from Python loads the .so built from this file, which
runs the TORCH_LIBRARY static initializers below and registers the
operators with PyTorch. */
PyObject* PyInit__C(void)
{
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
}

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
@@ -513,30 +492,30 @@ inline int round_up_headdimv(int head_size) {
// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
at::Tensor
mha_fwd_get_scheduler_metadata(
int batch_size,
int max_seqlen_q,
int max_seqlen_k,
int num_heads,
int num_heads_k,
int headdim,
int headdim_v,
int64_t batch_size,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
int64_t num_heads,
int64_t num_heads_k,
int64_t headdim,
int64_t headdim_v,
at::ScalarType qkv_dtype,
const at::Tensor &seqused_k, // b
std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor> &leftpad_k_, // b
std::optional<int> page_size,
int max_seqlen_k_new, // 0 means we're not appending new KV
at::Tensor seqused_k, // b
std::optional<at::Tensor> cu_seqlens_q_, // b+1
std::optional<at::Tensor> cu_seqlens_k_, // b+1
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<at::Tensor> leftpad_k_, // b
std::optional<int64_t> page_size,
int64_t max_seqlen_k_new, // 0 means we're not appending new KV
bool is_causal,
int window_size_left,
int window_size_right,
int attention_chunk,
int64_t window_size_left,
int64_t window_size_right,
int64_t attention_chunk,
bool has_softcap,
int num_splits,
int64_t num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin
int64_t sm_margin
) {

TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
@@ -645,42 +624,42 @@ mha_fwd_get_scheduler_metadata(
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor> &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor> &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor> &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor> &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor> cu_seqlens_q_, // b+1
std::optional<at::Tensor> cu_seqlens_k_, // b+1
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int64_t> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor> &page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor> &kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor> &leftpad_k_, // b
std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor> &seqlens_rotary_, // b
std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor> &k_descale_, // (b, h_k)
std::optional<at::Tensor> &v_descale_, // (b, h_k)
float const softmax_scale,
std::optional<int64_t> max_seqlen_k_,
std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
std::optional<at::Tensor> leftpad_k_, // b
std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<at::Tensor> seqlens_rotary_, // b
std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor> k_descale_, // (b, h_k)
std::optional<at::Tensor> v_descale_, // (b, h_k)
double softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
int attention_chunk,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
int num_splits,
int64_t window_size_left,
int64_t window_size_right,
int64_t attention_chunk,
double softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
int64_t num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin
int64_t sm_margin
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1211,29 +1190,30 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::vector<at::Tensor> mha_bwd(
const at::Tensor &dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
const at::Tensor &v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
const at::Tensor &out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
std::optional<at::Tensor> &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
std::optional<at::Tensor> &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
std::optional<at::Tensor> &dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
std::optional<int> max_seqlen_k_,
float const softmax_scale,
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
std::optional<at::Tensor> dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
std::optional<at::Tensor> dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
std::optional<at::Tensor> dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
std::optional<at::Tensor> cu_seqlens_q_, // b+1
std::optional<at::Tensor> cu_seqlens_k_, // b+1
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int64_t> max_seqlen_q_,
std::optional<int64_t> max_seqlen_k_,
double softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const deterministic,
int const sm_margin) {
int64_t window_size_left,
int64_t window_size_right,
double softcap,
bool deterministic,
int64_t sm_margin
) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1507,9 +1487,9 @@ std::vector<at::Tensor> mha_bwd(
return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
}

std::vector<at::Tensor>
mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads
std::tuple<at::Tensor, at::Tensor>
mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads
std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
std::optional<at::ScalarType> out_dtype_
) {
@@ -1610,10 +1590,100 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x
return {out, softmax_lse};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
TORCH_LIBRARY(flash_attn_3, m) {
m.def("fwd("
"Tensor q,"
"Tensor k,"
"Tensor v,"
"Tensor(k_new!)? k_new,"
"Tensor(v_new!)? v_new,"
"Tensor? q_v,"
"Tensor(out!)? out,"
"Tensor? cu_seqlens_q,"
"Tensor? cu_seqlens_k,"
"Tensor? cu_seqlens_k_new,"
"Tensor? seqused_q,"
"Tensor? seqused_k,"
"int? max_seqlen_q,"
"int? max_seqlen_k,"
"Tensor? page_table,"
"Tensor? kv_batch_idx,"
"Tensor? leftpad_k,"
"Tensor? rotary_cos,"
"Tensor? rotary_sin,"
"Tensor? seqlens_rotary,"
"Tensor? q_descale,"
"Tensor? k_descale,"
"Tensor? v_descale,"
"float softmax_scale,"
"bool is_causal,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"int attention_chunk = 0,"
"float softcap = 0.0,"
"bool is_rotary_interleaved = False,"
"Tensor? scheduler_metadata = None,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
m.def("bwd("
"Tensor dout,"
"Tensor q,"
"Tensor k,"
"Tensor v,"
"Tensor out,"
"Tensor softmax_lse,"
"Tensor(dq!)? dq,"
"Tensor(dk!)? dk,"
"Tensor(dv!)? dv,"
"Tensor? cu_seqlens_q,"
"Tensor? cu_seqlens_k,"
"Tensor? seqused_q,"
"Tensor? seqused_k,"
"int? max_seqlen_q,"
"int? max_seqlen_k,"
"float softmax_scale,"
"bool is_causal,"
"int window_size_left = -1,"
"int window_size_right = -1,"
"float softcap = 0.0,"
"bool deterministic = False,"
"int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("fwd_combine("
"Tensor out_partial,"
"Tensor lse_partial,"
"Tensor(out!)? out = None,"
"ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
m.def("get_scheduler_metadata("
"int batch_size,"
"int max_seqlen_q,"
"int max_seqlen_k,"
"int num_heads,"
"int num_heads_k,"
"int headdim,"
"int headdim_v,"
"ScalarType qkv_dtype,"
"Tensor seqused_k,"
"Tensor? cu_seqlens_q,"
"Tensor? cu_seqlens_k,"
"Tensor? cu_seqlens_k_new,"
"Tensor? seqused_q,"
"Tensor? leftpad_k,"
"int? page_size,"
"int max_seqlen_k_new,"
"bool is_causal,"
"int window_size_left,"
"int window_size_right,"
"int attention_chunk,"
"bool has_softcap = False,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> Tensor");
}

TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
m.impl("fwd", &mha_fwd);
m.impl("bwd", &mha_bwd);
m.impl("fwd_combine", &mha_combine);
m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
}
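A short, hedged sanity check of the new registration path (assuming the built extension is importable as `flash_attn_3._C`): the schema strings declared in `TORCH_LIBRARY` become visible under `torch.ops.flash_attn_3`, and the functions registered in `TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, ...)` are what the dispatcher routes to when the ops are called with CUDA tensors.

```python
import torch
import flash_attn_3._C  # noqa: F401 -- import is only for its registration side effect

fwd = torch.ops.flash_attn_3.fwd   # OpOverloadPacket created from the "fwd(...)" schema above
print(fwd.default._schema)         # the parsed schema, including the defaulted arguments
print(torch.ops.flash_attn_3.fwd_combine.default._schema)
```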
3 changes: 2 additions & 1 deletion hopper/flash_attn_interface.py
@@ -7,10 +7,11 @@

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_3_cuda
import flash_attn_3._C # Registers operators with PyTorch

# isort: on

flash_attn_3_cuda = torch.ops.flash_attn_3

def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
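A brief sketch of what the alias change above buys (under the same assumption that `flash_attn_3._C` is importable): because `flash_attn_3_cuda` now names the `torch.ops.flash_attn_3` namespace instead of a pybind11 module, existing call sites such as `flash_attn_3_cuda.fwd(...)` or `flash_attn_3_cuda.get_scheduler_metadata(...)` resolve unchanged, now to the registered custom ops.

```python
import torch
import flash_attn_3._C  # noqa: F401

flash_attn_3_cuda = torch.ops.flash_attn_3
print(flash_attn_3_cuda.fwd is torch.ops.flash_attn_3.fwd)  # True: both resolve to the same op packet
print(flash_attn_3_cuda.fwd_combine)                         # e.g. an OpOverloadPacket for flash_attn_3.fwd_combine
```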