diff --git a/.gitignore b/.gitignore
index 1f1f8028863..4a98f484d59 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,7 @@ var/
 
 # IDE-related
 .idea/
+.vscode/
 
 # Dev venv
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 30134990d68..78adbc21ce6 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -1389,6 +1389,9 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    dcp_rank=None,
+    dcp_world_size=None,
+    query_base_positions=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1445,6 +1448,14 @@ def flash_attn_varlen_func(
             The output of softmax (possibly with different scaling). It also encodes the dropout
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
+    # Context parallelism parameters are only supported in Flash Attention 3
+    # For Flash Attention 2, these parameters are ignored with a warning
+    if dcp_rank is not None or dcp_world_size is not None or query_base_positions is not None:
+        import warnings
+        warnings.warn("Context parallelism parameters (dcp_rank, dcp_world_size, query_base_positions) "
+                      "are only supported in Flash Attention 3. These parameters will be ignored.",
+                      UserWarning)
+
     return FlashAttnVarlenFunc.apply(
         q,
         k,
diff --git a/hopper/flash.h b/hopper/flash.h
index 28997613dc6..a05c122dba4 100644
--- a/hopper/flash.h
+++ b/hopper/flash.h
@@ -161,6 +161,11 @@ struct Flash_fwd_params : public Qkv_params {
 
     // The S extra matrix, (num_heads)
     void *__restrict__ s_aux_ptr;
+
+    // Context parallelism parameters for MLA decode
+    int dcp_rank;
+    int dcp_world_size;
+    int *__restrict__ query_base_positions;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 6a4f0e6ee65..5f57264b599 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -701,7 +701,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int num_splits,
         std::optional<bool> pack_gqa_,
         int const sm_margin,
-        std::optional<at::Tensor> &s_aux_ // (h)
+        std::optional<at::Tensor> &s_aux_, // (h)
+        int dcp_rank,
+        int dcp_world_size,
+        std::optional<at::Tensor> &query_base_positions_
         ) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1148,6 +1151,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         params.s_aux_ptr = nullptr;
     }
 
+    // Set context parallelism parameters
+    params.dcp_rank = dcp_rank;
+    params.dcp_world_size = dcp_world_size;
+    if (query_base_positions_.has_value()) {
+        params.query_base_positions = query_base_positions_.value().data_ptr<int>();
+    } else {
+        params.query_base_positions = nullptr;
+    }
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
     TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
 #endif
@@ -1280,7 +1292,10 @@ std::vector<at::Tensor> mha_bwd(
         int window_size_right,
         float const softcap,
         bool const deterministic,
-        int const sm_margin) {
+        int const sm_margin,
+        int dcp_rank,
+        int dcp_world_size,
+        std::optional<at::Tensor> &query_base_positions_) {
 
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1513,6 +1528,15 @@ std::vector<at::Tensor> mha_bwd(
     params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
     params.dv = head_size;  // We don't support hdim_v being different from hdim_qk for now
 
+    // Set context parallelism parameters
+    params.dcp_rank = dcp_rank;
+    params.dcp_world_size = dcp_world_size;
+    if (query_base_positions_.has_value()) {
+        params.query_base_positions = query_base_positions_.value().data_ptr<int>();
+    } else {
+        params.query_base_positions = nullptr;
+    }
+
     // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
     // params.tile_count_semaphore = tile_count_semaphore.data_ptr();  // Will be zero'ed out in the backward preprocess kernel
diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp
index ad2c515f9dd..b86dd82063f 100644
--- a/hopper/flash_api_torch_lib.cpp
+++ b/hopper/flash_api_torch_lib.cpp
@@ -52,7 +52,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int num_splits,
         std::optional<bool> pack_gqa_,
         int const sm_margin,
-        std::optional<at::Tensor> &s_aux_
+        std::optional<at::Tensor> &s_aux_,
+        int dcp_rank,
+        int dcp_world_size,
+        std::optional<at::Tensor> &query_base_positions_
         );
 
 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -120,7 +123,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
            " int num_splits,"
            " bool? pack_gqa,"
            " int sm_margin,"
-           " Tensor? s_aux) -> Tensor[]");
+           " Tensor? s_aux,"
+           " int dcp_rank,"
+           " int dcp_world_size,"
+           " Tensor? query_base_positions) -> Tensor[]");
     ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 
     ops.def("get_scheduler_metadata("
diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index d3150fbb671..27609413ece 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -49,7 +49,10 @@ def _flash_attn_forward(
     num_splits=1,
     pack_gqa=None,
     sm_margin=0,
-    s_aux=None):
+    s_aux=None,
+    dcp_rank=0,
+    dcp_world_size=1,
+    query_base_positions=None):
     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -61,6 +64,9 @@ def _flash_attn_forward(
     ]
     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
     seqlens_rotary = maybe_contiguous(seqlens_rotary)
+    # Handle context parallelism parameters
+    query_base_positions = maybe_contiguous(query_base_positions)
+
     out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
         q,
         k,
@@ -95,7 +101,10 @@ def _flash_attn_forward(
         num_splits,
         pack_gqa,
         sm_margin,
-        s_aux
+        s_aux,
+        dcp_rank,
+        dcp_world_size,
+        query_base_positions
     )
     return out, softmax_lse, *rest
 
@@ -122,9 +131,15 @@ def _flash_attn_backward(
     softcap=0.0,
     deterministic=False,
     sm_margin=0,
+    dcp_rank=0,
+    dcp_world_size=1,
+    query_base_positions=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    # Handle context parallelism parameters
+    query_base_positions = maybe_contiguous(query_base_positions)
+
     dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
         dout,
         q,
@@ -148,6 +163,9 @@ def _flash_attn_backward(
         softcap,
         deterministic,
         sm_margin,
+        dcp_rank,
+        dcp_world_size,
+        query_base_positions,
     )
     return dq, dk, dv, softmax_d
 
@@ -351,6 +369,9 @@ def forward(
         deterministic=False,
         sm_margin=0,
         s_aux=None,
+        dcp_rank=0,
+        dcp_world_size=1,
+        query_base_positions=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -380,6 +401,9 @@ def forward(
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
             s_aux=s_aux,
+            dcp_rank=dcp_rank,
+            dcp_world_size=dcp_world_size,
+            query_base_positions=query_base_positions,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -391,6 +415,10 @@ def forward(
         ctx.softcap = softcap
         ctx.deterministic = deterministic
         ctx.sm_margin = sm_margin
+        # Save context parallelism parameters for backward pass
+        ctx.dcp_rank = dcp_rank
+        ctx.dcp_world_size = dcp_world_size
+        ctx.query_base_positions = query_base_positions
         return out, softmax_lse
 
     @staticmethod
@@ -419,11 +447,14 @@ def backward(ctx, dout, *args):
             ctx.softcap,
             ctx.deterministic,
             ctx.sm_margin,
+            ctx.dcp_rank,
+            ctx.dcp_world_size,
+            ctx.query_base_positions,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -582,6 +613,9 @@ def flash_attn_varlen_func(
     deterministic=False,
     sm_margin=0,
     s_aux=None,
+    dcp_rank=0,
+    dcp_world_size=1,
+    query_base_positions=None,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -604,6 +638,9 @@ def flash_attn_varlen_func(
         deterministic,
         sm_margin,
         s_aux,
+        dcp_rank,
+        dcp_world_size,
+        query_base_positions,
     )
 
 
diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
index 50a2d7fb80f..37d4a59cdea 100644
--- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
+++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -412,6 +412,9 @@ struct CollectiveMainloopFwdSm90 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        int const dcp_rank = 0;
+        int const dcp_world_size = 1;
+        int const* const query_base_positions = nullptr;
     };
 
     // Device side kernel params
@@ -469,6 +472,9 @@ struct CollectiveMainloopFwdSm90 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        int const dcp_rank = 0;
+        int const dcp_world_size = 1;
+        int const* const query_base_positions = nullptr;
     };
 
     static Params
@@ -584,7 +590,8 @@ struct CollectiveMainloopFwdSm90 {
                 args.kv_batch_idx,
                 args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
                 args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
-                args.ptr_S_aux};
+                args.ptr_S_aux,
+                args.dcp_rank, args.dcp_world_size, args.query_base_positions};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -1093,7 +1100,9 @@ struct CollectiveMainloopFwdSm90 {
         // But we subtract n_offset for consistency in mask calculations
         flash::Mask mask(
             thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
-            params.qhead_per_khead_divmod
+            params.qhead_per_khead_divmod,
+            params.dcp_rank, params.dcp_world_size,
+            params.query_base_positions != nullptr ? params.query_base_positions[bidb] : 0
         );
 
         float softcap_val = params.softcap_val;
diff --git a/hopper/mask.h b/hopper/mask.h
index 02d046268cf..b84ee98da4d 100644
--- a/hopper/mask.h
+++ b/hopper/mask.h
@@ -23,11 +23,15 @@ struct Mask {
     int const seqlen_q, seqlen_k;
     int const window_size_left, window_size_right, sink_token_length;
    cutlass::FastDivmod const qhead_per_khead_divmod;
+    // Context parallelism parameters for MLA decode
+    int const dcp_rank, dcp_world_size;
+    int const query_base_position;
 
     CUTLASS_DEVICE
     Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
          const int window_size_left, const int window_size_right, const int sink_token_length,
-         cutlass::FastDivmod const &qhead_per_khead_divmod)
+         cutlass::FastDivmod const &qhead_per_khead_divmod,
+         const int dcp_rank = 0, const int dcp_world_size = 1, const int query_base_position = 0)
         : thread_idx(thread_idx)
         , seqlen_q(seqlen_q)
         , seqlen_k(seqlen_k)
@@ -35,6 +39,9 @@ struct Mask {
         , window_size_right(window_size_right)
         , sink_token_length(sink_token_length)
         , qhead_per_khead_divmod(qhead_per_khead_divmod)
+        , dcp_rank(dcp_rank)
+        , dcp_world_size(dcp_world_size)
+        , query_base_position(query_base_position)
     {
     };
 
@@ -89,8 +96,12 @@ struct Mask {
                 int const row_idx = !PackGQA
                     ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM
                     : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
+                // For context parallelism, adjust causal mask based on global query position
+                int const global_row_idx = query_base_position + row_idx;
+                int const kv_offset = (dcp_world_size > 1) ? (seqlen_k * dcp_rank / dcp_world_size) : 0;
+                int const adjusted_causal_row_offset = causal_row_offset - kv_offset;
                 int const col_limit_right = !Seqlenk_mask
-                    ? row_idx + causal_row_offset
+                    ? ((dcp_world_size > 1) ? global_row_idx + adjusted_causal_row_offset : row_idx + causal_row_offset)
                     : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
                 #pragma unroll
                 for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 06de7fd17b7..5e72593a4ef 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -146,6 +146,10 @@ def flash_attn_varlen_func(
     # Version selector
     fa_version: int = DEFAULT_FA_VERSION,
     s_aux=None,
+    # Context parallelism parameters for MLA decode
+    dcp_rank=None,
+    dcp_world_size=None,
+    query_base_positions=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -255,6 +259,11 @@ def flash_attn_varlen_func(
         )
     elif fa_version == 3:
         assert alibi_slopes is None, "Alibi is not supported in FA3"
+
+        # Handle context parallelism parameters - convert None to default values
+        dcp_rank_val = dcp_rank if dcp_rank is not None else 0
+        dcp_world_size_val = dcp_world_size if dcp_world_size is not None else 1
+
         out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
             q, k, v,
             None, None,  # k_new, v_new
@@ -279,7 +288,10 @@ def flash_attn_varlen_func(
             num_splits,
             None,  # pack_gqa
             0,  # sm_margin
-            s_aux  # s_aux
+            s_aux,  # s_aux
+            dcp_rank_val,  # dcp_rank
+            dcp_world_size_val,  # dcp_world_size
+            query_base_positions  # query_base_positions
         )
     else:
         raise ValueError(f"Unsupported FA version: {fa_version}")
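
For reference, a minimal Python sketch of the causal-mask adjustment introduced in hopper/mask.h above (illustrative only, not part of the patch). It assumes seqlen_k is the full KV length, so that each decode-context-parallel rank's shard starts at seqlen_k * dcp_rank / dcp_world_size, as the offset formula suggests; the function name is hypothetical and causal_row_offset is treated as a precomputed input, as in the kernel.

def col_limit_right(row_idx, causal_row_offset, seqlen_k,
                    dcp_rank=0, dcp_world_size=1, query_base_position=0):
    # Mirrors the !Seqlenk_mask branch of hopper/mask.h after this change:
    # rebase the local query row to its global position, then pull the causal
    # boundary back by the offset at which this rank's KV shard starts.
    if dcp_world_size > 1:
        global_row_idx = query_base_position + row_idx
        kv_offset = seqlen_k * dcp_rank // dcp_world_size  # integer division, as in the kernel
        return global_row_idx + (causal_row_offset - kv_offset)
    return row_idx + causal_row_offset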