Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flash_attn/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def to_cute_block_sparse_tensors(
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
if not is_block_sparsity_enabled(tensors):
return None

(
mask_block_cnt,
mask_block_idx,
Expand Down
6 changes: 5 additions & 1 deletion flash_attn/cute/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@

logger = logging.getLogger(__name__)
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
_handler.setFormatter(
logging.Formatter(
"%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
)
logger.addHandler(_handler)
logger.setLevel(logging.DEBUG)

Expand Down
202 changes: 156 additions & 46 deletions flash_attn/cute/flash_fwd_sm100.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ def _flash_attn_fwd(
causal, window_size_left, window_size_right, mask_mod
)

# In fake mode (CPU-only compilation), use a fake stream placeholder.
requested_use_clc_scheduler = utils._get_use_clc_scheduler_default()
requested_disable_2cta = utils._get_disable_2cta_default()

current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

# SM80/SM120: uses SM80 MMA, 128 threads (4 warps)
Expand Down Expand Up @@ -517,6 +519,7 @@ def _flash_attn_fwd(

use_2cta_instrs = (
arch // 10 in [10, 11]
and not requested_disable_2cta
and not causal
and not local
and not is_split_kv
Expand Down Expand Up @@ -621,6 +624,7 @@ def _flash_attn_fwd(
q_subtile_factor,
mma_pv_is_rs,
intra_wg_overlap,
requested_use_clc_scheduler,
fa_logging.get_fa_log_level(),
)
if compile_key not in _flash_attn_fwd.compile_cache:
Expand Down Expand Up @@ -728,6 +732,7 @@ def _flash_attn_fwd(
is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
q_subtile_factor=q_subtile_factor,
use_2cta_instrs=use_2cta_instrs,
use_clc_scheduler=requested_use_clc_scheduler,
)
elif arch // 10 == 12:
# SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity
Expand Down Expand Up @@ -1056,8 +1061,10 @@ def _flash_attn_bwd(
dKV_swapAB = False
AtomLayoutMdQ = 1
AtomLayoutNdKV = 1
requested_disable_2cta = utils._get_disable_2cta_default()
disable_2cta = (
score_mod is not None
requested_disable_2cta
or score_mod is not None
or score_mod_bwd is not None
or mask_mod is not None
)
Expand Down
Loading