3 changes: 3 additions & 0 deletions hopper/flash.h
@@ -158,6 +158,9 @@ struct Flash_fwd_params : public Qkv_params {

int arch;
int num_sm;

// The auxiliary S values ("S extra"), one per head: (num_heads)
void *__restrict__ s_aux_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
15 changes: 14 additions & 1 deletion hopper/flash_api.cpp
@@ -666,7 +666,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin
int const sm_margin,
std::optional<const at::Tensor> &s_aux_ // (h)
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1091,6 +1092,18 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
}
}

if (s_aux_.has_value()) {
    auto s_aux = s_aux_.value();
    TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16,
                "We only support bf16 dtype for S extra.");
    CHECK_DEVICE(s_aux);
    CHECK_SHAPE(s_aux, num_heads);
    CHECK_CONTIGUOUS(s_aux);
    params.s_aux_ptr = s_aux.data_ptr();
} else {
    params.s_aux_ptr = nullptr;
}

#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
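The validation added above only accepts `s_aux` as a contiguous bf16 CUDA tensor of shape `(num_heads,)`; when the optional is empty, `params.s_aux_ptr` stays null and the kernel behaves as before. A minimal sketch of those constraints from the caller's side (sizes and variable names here are illustrative, not part of the PR):

```python
import torch

num_heads = 8  # illustrative head count

# Accepted: a contiguous bf16 vector with one entry per attention head, on the GPU.
s_aux = torch.zeros(num_heads, dtype=torch.bfloat16, device="cuda")

# Each of these would trip one of the checks above:
# torch.zeros(num_heads, dtype=torch.float32, device="cuda")      # wrong dtype (not bf16)
# torch.zeros(num_heads, 2, dtype=torch.bfloat16, device="cuda")  # wrong shape (not (num_heads,))
# torch.zeros(num_heads, dtype=torch.bfloat16)                    # CPU tensor fails CHECK_DEVICE
```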
6 changes: 4 additions & 2 deletions hopper/flash_api_torch_lib.cpp
@@ -51,7 +51,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin
int const sm_margin,
std::optional<const at::Tensor> &s_aux_
);

// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -118,7 +119,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor[]");
" int sm_margin,"
" Tensor? s_aux) -> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

ops.def("get_scheduler_metadata("
20 changes: 16 additions & 4 deletions hopper/flash_attn_interface.py
@@ -48,7 +48,8 @@ def _flash_attn_forward(
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0):
sm_margin=0,
s_aux=None):
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -94,6 +95,7 @@ def _flash_attn_forward(
num_splits,
pack_gqa,
sm_margin,
s_aux
)
return out, softmax_lse, *rest

@@ -233,7 +235,7 @@ def backward(ctx, dout, *args):
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.deterministic,
ctx.deterministic,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None, None, None
@@ -257,6 +259,7 @@ def forward(
pack_gqa=None,
deterministic=False,
sm_margin=0,
s_aux=None,
):
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -281,6 +284,7 @@ def forward(
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
)
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -319,7 +323,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -346,6 +350,7 @@ def forward(
pack_gqa=None,
deterministic=False,
sm_margin=0,
s_aux=None,
):
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -374,6 +379,7 @@ def forward(
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
)
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -417,7 +423,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
@@ -490,6 +496,7 @@ def flash_attn_func(
pack_gqa=None,
deterministic=False,
sm_margin=0,
s_aux=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -550,6 +557,7 @@ def flash_attn_func(
pack_gqa,
deterministic,
sm_margin,
s_aux,
)


@@ -573,6 +581,7 @@ def flash_attn_varlen_func(
pack_gqa=None,
deterministic=False,
sm_margin=0,
s_aux=None,
):
return FlashAttnVarlenFunc.apply(
q,
@@ -594,6 +603,7 @@ def flash_attn_varlen_func(
pack_gqa,
deterministic,
sm_margin,
s_aux,
)


@@ -631,6 +641,7 @@ def flash_attn_with_kvcache(
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
s_aux=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -757,6 +768,7 @@ def flash_attn_with_kvcache(
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
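With the Python plumbing above in place, the per-head vector passes straight through the public entry points. A minimal usage sketch, assuming an SM90 (Hopper) build of this package, the `flash_attn_interface` import path from this directory, and illustrative shapes:

```python
import torch
from flash_attn_interface import flash_attn_func  # hopper interface module changed above

batch, seqlen, num_heads, head_dim = 2, 1024, 8, 128  # illustrative sizes
dtype, device = torch.bfloat16, "cuda"

q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=dtype, device=device)

# One auxiliary S value per head: bf16, contiguous, on the same GPU as q/k/v.
s_aux = torch.zeros(num_heads, dtype=torch.bfloat16, device=device)

result = flash_attn_func(q, k, v, causal=True, s_aux=s_aux)  # return value(s) unchanged by this PR
```

As the extended `backward` return tuples above show (one extra `None` per function), `s_aux` receives no gradient in this PR.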
10 changes: 10 additions & 0 deletions hopper/flash_fwd_kernel_sm90.h
@@ -52,6 +52,8 @@ class FlashAttnFwdSm90 {
static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;

using SmemLayoutSAux = typename CollectiveMainloop::SmemLayoutSAux;

// Mainloop derived types
using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
@@ -295,6 +297,14 @@ class FlashAttnFwdSm90 {
CollectiveMainloop mainloop;
CollectiveEpilogue epilogue;

const int num_heads = get<2>(params.mainloop.shape_Q);
Tensor gS_aux = make_tensor(make_gmem_ptr(params.mainloop.ptr_S_aux), make_shape(num_heads));
Tensor sS_aux = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_s_aux.data()), SmemLayoutSAux{});

// Stage the per-head auxiliary S values from global to shared memory, one element per thread.
if (params.mainloop.ptr_S_aux && threadIdx.x < num_heads) {
    sS_aux(threadIdx.x) = gS_aux(threadIdx.x);
}

// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
8 changes: 5 additions & 3 deletions hopper/flash_fwd_launch_template.h
@@ -35,6 +35,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
using ElementS = cutlass::bfloat16_t;

// Can't use structured binding since it's not compatible with constexpr
static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg);
@@ -52,8 +53,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
using CollectiveMainloop = std::conditional_t<
Arch >= 90,
flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split>
flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor, ElementS>,
flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split, ElementS>
>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;

@@ -127,7 +128,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
params.kv_batch_idx,
params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
params.seqused_q, params.seqused_k,
params.leftpad_k, params.seqlens_rotary
params.leftpad_k, params.seqlens_rotary,
static_cast<ElementS const*>(params.s_aux_ptr)
};
typename CollectiveEpilogue::Arguments epilogue_args {
static_cast<ElementOut*>(params.o_ptr),
8 changes: 6 additions & 2 deletions hopper/mainloop_fwd_sm80.hpp
@@ -25,7 +25,7 @@ using namespace cute;

template <int kNWarps, int Stages, bool Q_in_regs, class TileShape_MNK_, int kHeadDimV, class Element_, class ElementAccum_, class ArchTag_,
bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_,
bool PackGQA_, bool Split_>
bool PackGQA_, bool Split_, class ElementSAux_>
struct CollectiveMainloopFwdSm80 {

static constexpr int kStages = Stages;
@@ -34,6 +34,7 @@ struct CollectiveMainloopFwdSm80 {
using TileShape_MNK_PV = Shape<decltype(get<0>(TileShape_MNK{})), Int<kHeadDimV>, decltype(get<1>(TileShape_MNK{}))>;
using Element = Element_;
using ElementAccum = ElementAccum_;
using ElementSAux = ElementSAux_;
using ArchTag = ArchTag_;
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
static constexpr bool Is_causal = Is_causal_;
@@ -213,6 +214,7 @@ struct CollectiveMainloopFwdSm80 {
int const* const seqused_k = nullptr;
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
};

// Device side kernel params
@@ -258,6 +260,7 @@ struct CollectiveMainloopFwdSm80 {
int const* const seqused_k = nullptr;
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
};

static Params
@@ -297,7 +300,8 @@ struct CollectiveMainloopFwdSm80 {
!Split ? 1 : args.num_splits,
args.kv_batch_idx,
args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary};
args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
args.ptr_S_aux};
}

template <typename SharedStorage, typename FrgTensorO, typename Softmax>