diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index cefb48e7e24..48cd3a9010a 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -14,6 +14,10 @@ from cutlass.cute.runtime import from_dlpack +def ceildiv(a: int, b: int) -> int: + return (a + b - 1) // b + + # placeholder Config = type("Config", (), {}) @@ -78,6 +82,26 @@ def _check_and_expand_block( return expanded_cnt, expanded_idx +def get_block_sparse_expected_shapes( + batch_size: int, + num_head: int, + seqlen_q: int, + seqlen_k: int, + m_block_size: int, + n_block_size: int, + compute_capability: int, +) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]: + """Return (expected_count_shape, expected_index_shape) for block sparse normalization.""" + # TODO: This multiplier should really be q_stage, wire up in later PR + # 1 cta handles 2*tile_m rows on SM100 + m_block_size_effective = 2 * m_block_size if compute_capability == 10 else m_block_size + expected_m_blocks = ceildiv(seqlen_q, m_block_size_effective) + expected_n_blocks = ceildiv(seqlen_k, n_block_size) + expected_count_shape = (batch_size, num_head, expected_m_blocks) + expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks) + return expected_count_shape, expected_index_shape + + def normalize_block_sparse_tensors( tensors: BlockSparseTensorsTorch, *, @@ -205,8 +229,8 @@ def _compute_sparsity( config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Computes block sparsity for fixed-length sequences.""" - n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m - n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + n_blocks_q = ceildiv(config.seqlen_q, config.tile_m) + n_blocks_k = ceildiv(config.seqlen_k, config.tile_n) # Pre-allocate output tensors full_block_cnt = torch.zeros( @@ -325,12 +349,12 @@ def _compute_varlen_sparsity( max_m_blocks = 0 for seq_idx in range(config.batch_size): seq_len_q = (cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]).item() - n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_q = ceildiv(seq_len_q, config.tile_m) max_m_blocks = max(max_m_blocks, n_blocks_q) # The number of K blocks is determined by the total length of all sequences. total_k_len = cu_seqlens_k[-1].item() - max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n + max_n_blocks = ceildiv(total_k_len, config.tile_n) # Pre-allocate padded output tensors full_block_cnt = torch.zeros( @@ -360,8 +384,8 @@ def _compute_varlen_sparsity( seq_end_k = cu_seqlens_k[seq_idx + 1].item() seq_len_k = seq_end_k - seq_start_k - n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m - n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + n_blocks_q = ceildiv(seq_len_q, config.tile_m) + n_blocks_k = ceildiv(seq_len_k, config.tile_n) # Global block indices are relative to the start of the entire batch tensor first_m_block_global = seq_start_q // config.tile_m diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index c181f0e281f..5ed87e17d14 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,7 +1,5 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. -# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. -# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype @@ -22,10 +20,17 @@ # - bwd pass optimized for Hopper/Blackwell import math +from functools import lru_cache from typing import Optional, Tuple, Callable import torch + +@lru_cache(maxsize=None) +def _get_device_capability(): + """Cached device capability check.""" + return torch.cuda.get_device_capability()[0] + import cuda.bindings.driver as cuda import cutlass @@ -46,6 +51,7 @@ BlockSparseTensorsTorch, to_cute_block_sparse_tensors, normalize_block_sparse_tensors, + get_block_sparse_expected_shapes, ) def maybe_contiguous(x): @@ -58,6 +64,15 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" assert t.is_cuda, f"{name} must be on CUDA" +def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False): + """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True) + if fully_dynamic: + return tensor.mark_layout_dynamic() + if leading_dim == -1: + leading_dim = t.ndim - 1 + return tensor.mark_layout_dynamic(leading_dim=leading_dim) + torch2cute_dtype_map = { torch.float16: cutlass.Float16, @@ -230,51 +245,15 @@ def _flash_attn_fwd( _validate_tensor(lse, "lse", lse_shape, torch.float32, device) dtype = torch2cute_dtype_map[q.dtype] - ( - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - learnable_sink_tensor, - ) = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) - if t is not None - else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) - ] - page_table_tensor = ( - from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) - if page_table is not None - else None - ) compute_capability = ( - torch.cuda.get_device_capability()[0] + _get_device_capability() if _compute_capability is None else _compute_capability ) assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" - - sparse_tensors = None - if block_sparse_tensors is not None: - if seqlen_q is None: - raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") - m_block_size_block = m_block_size - if compute_capability == 10: - # TODO: This multiplier should really be q_stage, wire up in later PR - # 1 cta handles 2*tile_m row - m_block_size_block = 2 * m_block_size - expected_m_blocks = (seqlen_q + m_block_size_block - 1) // m_block_size_block - expected_n_blocks = (seqlen_k + n_block_size - 1) // n_block_size - block_sparse_tensors = normalize_block_sparse_tensors( - block_sparse_tensors, - expected_count_shape=(batch_size, num_head, expected_m_blocks), - expected_index_shape=(batch_size, num_head, expected_m_blocks, expected_n_blocks), - ) - sparse_tensors = to_cute_block_sparse_tensors(block_sparse_tensors) - - use_block_sparsity = sparse_tensors is not None + use_block_sparsity = block_sparse_tensors is not None if mask_mod is None: if causal: @@ -327,17 +306,6 @@ def _flash_attn_fwd( out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) - q_tensor, k_tensor, v_tensor, o_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) - for t in (q, k, v, out if not is_split_kv else out_partial) - ] - if is_split_kv: - lse_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 1) - elif lse is not None: - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) - else: - lse_tensor = None - # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False @@ -377,10 +345,6 @@ def _flash_attn_fwd( "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." ) - cute_aux_tensors = None - if aux_tensors is not None: - cute_aux_tensors = [from_dlpack(buf).mark_layout_dynamic() for buf in aux_tensors] - compile_key = ( dtype, head_dim, @@ -409,6 +373,52 @@ def _flash_attn_fwd( page_size not in [None, 128], # paged KV non-TMA ) if compile_key not in _flash_attn_fwd.compile_cache: + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + learnable_sink_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) + if t is not None + else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) + ] + page_table_tensor = ( + to_cute_tensor(page_table, assumed_align=4, leading_dim=1) + if page_table is not None + else None + ) + q_tensor, k_tensor, v_tensor, o_tensor = [ + to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial) + ] + if is_split_kv: + lse_tensor = to_cute_tensor(lse_partial, assumed_align=4) + elif lse is not None: + lse_tensor = to_cute_tensor(lse, assumed_align=4) + else: + lse_tensor = None + + sparse_tensors = None + if block_sparse_tensors is not None: + if seqlen_q is None: + raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") + expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes( + batch_size, num_head, seqlen_q, seqlen_k, + m_block_size, n_block_size, compute_capability, + ) + compile_time_normalized = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=expected_count_shape, + expected_index_shape=expected_index_shape, + ) + sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized) + + cute_aux_tensors = None + if aux_tensors is not None: + cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] + if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" assert not is_split_kv, "SplitKV not supported on SM 9.0" @@ -480,25 +490,40 @@ def _flash_attn_fwd( learnable_sink_tensor, sparse_tensors, cute_aux_tensors, + options="--enable-tvm-ffi", + ) + + # Expand block sparse tensors to match actual head count (may be broadcast from 1) + normalized_block_sparse_tensors = None + if block_sparse_tensors is not None: + expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes( + batch_size, num_head, seqlen_q, seqlen_k, + m_block_size, n_block_size, compute_capability, + ) + normalized_block_sparse_tensors = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=expected_count_shape, + expected_index_shape=expected_index_shape, ) + _flash_attn_fwd.compile_cache[compile_key]( - q_tensor, - k_tensor, - v_tensor, - o_tensor, - lse_tensor, + q, + k, + v, + out if not is_split_kv else out_partial, + lse_partial if is_split_kv else lse, softmax_scale, current_stream, - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - page_table_tensor, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, window_size_left, window_size_right, - learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, + learnable_sink, + normalized_block_sparse_tensors, + aux_tensors, ) if is_split_kv: _flash_attn_fwd_combine( @@ -549,7 +574,7 @@ def _flash_attn_bwd( dk: Optional[torch.Tensor] = None, dv: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - compute_capability = torch.cuda.get_device_capability()[0] + compute_capability = _get_device_capability() assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" if compute_capability == 9: @@ -747,28 +772,8 @@ def _flash_attn_bwd( ) dtype = torch2cute_dtype_map[q.dtype] - q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) - for t in (q, k, v, out, dout, dq, dk, dv) - ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=lse.ndim - 1 - ) - dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) - for t in (dq_accum, dpsum, lse_log2) - ] - if qhead_per_kvhead > 1: - dk_accum_tensor, dv_accum_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) - for t in (dk_accum, dv_accum) - ] - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim - 1) - if t is not None - else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) - ] + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + if deterministic: dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") else: @@ -780,16 +785,19 @@ def _flash_attn_bwd( else: dK_semaphore = None dV_semaphore = None - dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ - utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) - if t is not None else None - for t in (dQ_semaphore, dK_semaphore, dV_semaphore) - ] - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads) if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: + o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)] + dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ + to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) + ] + lse_tensor = to_cute_tensor(lse, assumed_align=4) + cu_seqlens_q_tensor, seqused_q_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_q, seqused_q) + ] fa_bwd_pre = FlashAttentionBackwardPreprocess( dtype, head_dim_v, @@ -808,16 +816,17 @@ def _flash_attn_bwd( cu_seqlens_q_tensor, seqused_q_tensor, current_stream, + options="--enable-tvm-ffi", ) _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - o_tensor, - do_tensor, - dpsum_tensor, - lse_tensor, - lse_log2_tensor, - dq_accum_tensor, - cu_seqlens_q_tensor, - seqused_q_tensor, + out, + dout, + dpsum, + lse, + lse_log2, + dq_accum, + cu_seqlens_q, + seqused_q, current_stream, ) @@ -865,6 +874,25 @@ def _flash_attn_bwd( ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: + q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ + to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv) + ] + dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ + to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) + ] + if qhead_per_kvhead > 1: + dk_accum_tensor, dv_accum_tensor = [ + to_cute_tensor(t) for t in (dk_accum, dv_accum) + ] + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] + dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ + utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) + if t is not None else None + for t in (dQ_semaphore, dK_semaphore, dV_semaphore) + ] fa_bwd_sm80 = FlashAttentionBackwardSm80( dtype, head_dim, @@ -937,39 +965,48 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - window_size_left=window_size_left, - window_size_right=window_size_right, - mdQ_semaphore=dQ_semaphore_tensor, - mdK_semaphore=dK_semaphore_tensor, - mdV_semaphore=dV_semaphore_tensor, + None, # softcap - not yet supported in backward + window_size_left, + window_size_right, + dQ_semaphore_tensor, + dK_semaphore_tensor, + dV_semaphore_tensor, + options="--enable-tvm-ffi", ) _flash_attn_bwd.compile_cache[compile_key]( - q_tensor, - k_tensor, - v_tensor, - do_tensor, - lse_log2_tensor, - dpsum_tensor, - dq_accum_tensor, - dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, - dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, + q, + k, + v, + dout, + lse_log2, + dpsum, + dq_accum, + dk if qhead_per_kvhead == 1 else dk_accum, + dv if qhead_per_kvhead == 1 else dv_accum, softmax_scale, current_stream, - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - window_size_left=window_size_left, - window_size_right=window_size_right, - mdQ_semaphore=dQ_semaphore_tensor, - mdK_semaphore=dK_semaphore_tensor, - mdV_semaphore=dV_semaphore_tensor, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + None, # softcap - not yet supported in backward + window_size_left, + window_size_right, + dQ_semaphore, + dK_semaphore, + dV_semaphore, ) num_threads = 256 if compute_capability == 9 else 128 # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: + dq_accum_tensor = to_cute_tensor(dq_accum) + dq_tensor = to_cute_tensor(dq) + cu_seqlens_q_tensor, seqused_q_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_q, seqused_q) + ] arch = compute_capability * 10 fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB @@ -983,13 +1020,14 @@ def _flash_attn_bwd( cu_seqlens_q_tensor, seqused_q_tensor, current_stream, + options="--enable-tvm-ffi", ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum_tensor, - dq_tensor, + dq_accum, + dq, softmax_scale, - cu_seqlens_q_tensor, - seqused_q_tensor, + cu_seqlens_q, + seqused_q, current_stream, ) @@ -997,6 +1035,12 @@ def _flash_attn_bwd( # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 compile_key_post = (dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: + dk_accum_tensor = to_cute_tensor(dk_accum) + dk_tensor = to_cute_tensor(dk) + cu_seqlens_k_tensor, seqused_k_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_k, seqused_k) + ] fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) @@ -1009,13 +1053,14 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_k_tensor, current_stream, + options="--enable-tvm-ffi", ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum_tensor, - dk_tensor, + dk_accum, + dk, softmax_scale, - cu_seqlens_k_tensor, - seqused_k_tensor, + cu_seqlens_k, + seqused_k, current_stream, ) compile_key_post = ( @@ -1027,6 +1072,12 @@ def _flash_attn_bwd( dKV_swapAB, ) if compile_key_post not in _flash_attn_bwd.compile_cache_post: + dv_accum_tensor = to_cute_tensor(dv_accum) + dv_tensor = to_cute_tensor(dv) + cu_seqlens_k_tensor, seqused_k_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_k, seqused_k) + ] fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) @@ -1039,13 +1090,14 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_k_tensor, current_stream, + options="--enable-tvm-ffi", ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum_tensor, - dv_tensor, - cutlass.Float32(1.0), - cu_seqlens_k_tensor, - seqused_k_tensor, + dv_accum, + dv, + 1.0, + cu_seqlens_k, + seqused_k, current_stream, ) @@ -1364,30 +1416,6 @@ def _flash_attn_fwd_combine( # TODO: we can deal w this by using 128 threads instead log_max_splits = max(log_max_splits, 5) - # Convert to cute tensors (using kernel-formatted tensors) - out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=4 if not is_varlen else 3 - ) - lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=lse_partial.ndim - 2 - ) - out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3 if not is_varlen else 2) - lse_tensor = ( - from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) - if lse is not None - else None - ) - - optional_tensors = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) - if t is not None - else None - for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) - ] - cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( - optional_tensors - ) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Create combine kernel configuration @@ -1407,6 +1435,28 @@ def _flash_attn_fwd_combine( ) if compile_key not in _flash_attn_fwd_combine.compile_cache: + out_partial_tensor = to_cute_tensor( + out_partial, leading_dim=4 if not is_varlen else 3 + ) + lse_partial_tensor = to_cute_tensor( + lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2 + ) + out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2) + lse_tensor = ( + to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2) + if lse is not None + else None + ) + + optional_tensors = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) + if t is not None + else None + for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) + ] + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( + optional_tensors + ) fa_combine = FlashAttentionForwardCombine( dtype=dtype, dtype_partial=dtype_partial, @@ -1441,17 +1491,17 @@ def _flash_attn_fwd_combine( num_splits_dynamic_tensor, semaphore_tensor, current_stream, + options="--enable-tvm-ffi", ) - _flash_attn_fwd_combine.compile_cache[compile_key]( - out_partial_tensor, - lse_partial_tensor, - out_tensor, - lse_tensor, - cu_seqlens_tensor, - seqused_tensor, - num_splits_dynamic_tensor, - semaphore_tensor, + out_partial, + lse_partial, + out, + lse, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + semaphore_to_reset, current_stream, ) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 8b5942b10d0..08e831913f0 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -22,10 +22,12 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl==4.3.0", + "nvidia-cutlass-dsl==4.3.3", "torch", "einops", "typing_extensions", + "apache-tvm-ffi>=0.1.5,<0.2", + "torch-c-dlpack-ext", ] [project.optional-dependencies]