Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flash_attn/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def to_cute_block_sparse_tensors(
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
if not is_block_sparsity_enabled(tensors):
return None

(
mask_block_cnt,
mask_block_idx,
Expand Down
231 changes: 189 additions & 42 deletions flash_attn/cute/flash_fwd_sm100.py

Large diffs are not rendered by default.

17 changes: 14 additions & 3 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def _flash_attn_fwd(
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
aux_tensors: Optional[list[torch.Tensor]] = None,
sched_stages: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for FlashAttention.

Expand Down Expand Up @@ -339,6 +340,9 @@ def _flash_attn_fwd(
causal, window_size_left, window_size_right, mask_mod
)

requested_use_clc_scheduler = utils._get_use_clc_scheduler_default()
requested_disable_2cta = utils._get_disable_2cta_default()

current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

fwd_cfg = FwdConfig(128, 128, True, True) # default
Expand Down Expand Up @@ -393,6 +397,7 @@ def _flash_attn_fwd(

use_2cta_instrs = (
arch // 10 in [10, 11]
and not requested_disable_2cta
and not causal
and not local
and not is_split_kv
Expand Down Expand Up @@ -496,6 +501,8 @@ def _flash_attn_fwd(
q_subtile_factor,
mma_pv_is_rs,
intra_wg_overlap,
requested_use_clc_scheduler,
sched_stages,
fa_logging.get_fa_log_level(),
)
if compile_key not in _flash_attn_fwd.compile_cache:
Expand Down Expand Up @@ -569,8 +576,8 @@ def _flash_attn_fwd(
is_local=local,
is_split_kv=is_split_kv,
pack_gqa=pack_gqa,
tile_m=tile_m,
tile_n=tile_n,
m_block_size=tile_m,
n_block_size=tile_n,
q_stage=q_stage,
is_persistent=not causal
and not local
Expand All @@ -582,6 +589,8 @@ def _flash_attn_fwd(
has_aux_tensors=aux_tensors is not None,
paged_kv_non_tma=page_size not in [None, 128],
is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
use_clc_scheduler=requested_use_clc_scheduler,
sched_stages=sched_stages,
q_subtile_factor=q_subtile_factor,
use_2cta_instrs=use_2cta_instrs,
)
Expand Down Expand Up @@ -811,8 +820,10 @@ def _flash_attn_bwd(
dKV_swapAB = False
AtomLayoutMdQ = 1
AtomLayoutNdKV = 1
requested_disable_2cta = utils._get_disable_2cta_default()
disable_2cta = (
local
requested_disable_2cta
or local
or score_mod is not None
or score_mod_bwd is not None
or mask_mod is not None
Expand Down
Loading