Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,20 +493,31 @@ def produce_block_sparse_loads_sm100(
pipeline_kv,
q_stage: cutlass.Constexpr,
q_producer_phase: Int32,
qhead_per_kvhead: cutlass.Constexpr,
):
"""SM100 entry point for sparse block iteration.

SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use
simplified block processing that just calls producer_acquire without extras.

Args:
m_block: which tile of m we are processing
qhead_per_kvhead: Constexpr pack factor
"""
# NB: Compute unpacked index for sparse tensor access
if const_expr(qhead_per_kvhead != 1):
m_block_sparse = m_block // qhead_per_kvhead
else:
m_block_sparse = m_block

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None]
curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

if const_expr(full_block_cnt is not None):
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
else:
curr_full_block_cnt = Int32(0)
curr_full_block_idx = None
Expand Down Expand Up @@ -574,15 +585,22 @@ def get_total_block_count(
batch_idx,
head_idx,
m_block,
qhead_per_kvhead: cutlass.Constexpr,
):
# NB: Convert packed m_block to unpacked for sparse tensor indexing
if const_expr(qhead_per_kvhead != 1):
m_block_sparse = m_block // qhead_per_kvhead
else:
m_block_sparse = m_block

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
if const_expr(full_block_cnt is not None):
return (
mask_block_cnt[batch_idx, head_idx, m_block]
+ full_block_cnt[batch_idx, head_idx, m_block]
mask_block_cnt[batch_idx, head_idx, m_block_sparse]
+ full_block_cnt[batch_idx, head_idx, m_block_sparse]
)
else:
return mask_block_cnt[batch_idx, head_idx, m_block]
return mask_block_cnt[batch_idx, head_idx, m_block_sparse]


@cute.jit
Expand Down Expand Up @@ -717,16 +735,23 @@ def softmax_block_sparse_sm100(
mbar_P_full_2_offset: Int32,
q_stage: cutlass.Constexpr,
stage_idx: Int32,
check_m_boundary: bool = False,
check_m_boundary: bool,
qhead_per_kvhead: cutlass.Constexpr,
):
# Convert packed m_block to unpacked for sparse tensor indexing
if const_expr(qhead_per_kvhead != 1):
m_block_sparse = m_block // qhead_per_kvhead
else:
m_block_sparse = m_block

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None]
curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

if const_expr(full_block_cnt is not None):
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
else:
curr_full_block_cnt = Int32(0)
curr_full_block_idx = None
Expand Down
8 changes: 5 additions & 3 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,7 @@ def load(
pipeline_kv,
self.q_stage,
q_producer_phase,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)


Expand Down Expand Up @@ -1366,7 +1367,7 @@ def mma(
process_tile = False

if const_expr(self.use_block_sparsity):
block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block)
block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
process_tile = block_iter_count > Int32(0)
else:
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
Expand Down Expand Up @@ -1674,7 +1675,7 @@ def softmax_loop(
softmax.reset()

if const_expr(self.use_block_sparsity):
tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block)
tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
has_work = tile_block_count > Int32(0)
else:
tile_block_count = n_block_max - n_block_min
Expand Down Expand Up @@ -1742,6 +1743,7 @@ def softmax_loop(
self.q_stage,
Int32(stage),
check_m_boundary,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)
if not empty_tile:
sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]
Expand Down Expand Up @@ -2034,7 +2036,7 @@ def correction_loop(
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage

if const_expr(self.use_block_sparsity):
total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block)
total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1)
has_work = total_block_count > Int32(0)
else:
total_block_count = n_block_max - n_block_min
Expand Down
25 changes: 11 additions & 14 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@
import torch


@lru_cache(maxsize=None)
def _get_device_capability():
"""Cached device capability check."""
return torch.cuda.get_device_capability()[0]

import cuda.bindings.driver as cuda

import cutlass
Expand All @@ -55,6 +50,11 @@ def _get_device_capability():
get_block_sparse_expected_shapes_bwd,
)

@lru_cache(maxsize=None)
def _get_device_capability():
"""Cached device capability check."""
return torch.cuda.get_device_capability()[0]

def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

Expand Down Expand Up @@ -327,20 +327,18 @@ def _flash_attn_fwd(
raise NotImplementedError(
"mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
)
if pack_gqa:
raise NotImplementedError(
"mask_mod with aux_tensors is not yet supported with pack_gqa=True. This will be fixed in a future PR."
)

if use_block_sparsity:
if is_varlen:
raise NotImplementedError(
"Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
)
if pack_gqa:
raise NotImplementedError(
"Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR."
)
# NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
pack_gqa = False
# SM90 doesn't support pack_gqa + block_sparsity yet
if pack_gqa and compute_capability == 9:
pack_gqa = False
if is_split_kv:
raise NotImplementedError(
"Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
Expand Down Expand Up @@ -506,7 +504,6 @@ def _flash_attn_fwd(
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
)

_flash_attn_fwd.compile_cache[compile_key](
q,
k,
Expand Down
44 changes: 28 additions & 16 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,14 @@ def apply_mask(
for r in cutlass.range_constexpr(nrow):
global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
row_for_mod = global_row_idx
head_idx_for_mod = head_idx
if const_expr(self.qhead_per_kvhead_packgqa != 1):
head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
row_for_seqlen = row_for_mod
if const_expr(wrap_aux_indices):
_, row_for_mod = divmod(global_row_idx, fastdiv_mods[0])
_, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])

for col in cutlass.range_constexpr(ncol):
col_idx_local = t0ScS_mn[0, col][1]
Expand All @@ -156,7 +162,7 @@ def apply_mask(
_, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])

batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
mask_value = mask_mod(
Expand All @@ -168,7 +174,7 @@ def apply_mask(
)
cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
if const_expr(mask_seqlen):
out_of_bounds = (global_row_idx >= self.seqlen_q) or (
out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
global_col_idx >= self.seqlen_k
)
if out_of_bounds:
Expand Down Expand Up @@ -346,26 +352,32 @@ def apply_mask_sm100(
and fastdiv_mods[1] is not None
)
batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
row_coord_first = tScS_t2r[0][0]
global_row = row_coord_first + m_block * self.tile_m
if const_expr(self.qhead_per_kvhead_packgqa != 1):
mask_row = global_row // self.qhead_per_kvhead_packgqa
else:
mask_row = global_row
mask_row_for_mod = mask_row
if const_expr(has_fastdiv and aux_tensors is not None):
if check_q_boundary:
_, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)

ncol = const_expr(cute.size(tScS_t2r.shape))
for i in cutlass.range_constexpr(ncol):
row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
global_row = row_coord + m_block * self.tile_m
global_col = col_coord + n_block * self.tile_n

if const_expr(self.qhead_per_kvhead_packgqa != 1):
head_offset = global_row % self.qhead_per_kvhead_packgqa
head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
mask_row = global_row // self.qhead_per_kvhead_packgqa
else:
head_idx_for_mod = head_idx
mask_row = global_row

mask_row_for_mod = mask_row
if const_expr(has_fastdiv and aux_tensors is not None):
if check_q_boundary:
_, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
global_col_for_mod = global_col
if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
_, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
mask_value = mask_mod(
batch_idx_ssa,
Expand All @@ -379,7 +391,7 @@ def apply_mask_sm100(
if const_expr(mask_seqlen):
acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
if check_q_boundary:
acc_S[i] = -Float32.inf if global_row >= self.seqlen_q else acc_S[i]
acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

else: # Causal or local
causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q
Expand Down
25 changes: 13 additions & 12 deletions flash_attn/cute/mask_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,21 +219,22 @@ def cute_ima_mask(


def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"):
"""Generate synthetic document ids shared across heads."""
doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device)
for b in range(batch):
N = seqlen_q
max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1))))
n = random.randint(1, max_segments)
n = min(n, N)
cuts = sorted(random.sample(range(1, N), n - 1))
lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))]
base_doc_ids = torch.repeat_interleave(
torch.arange(len(lengths), device=device, dtype=torch.int32),
torch.tensor(lengths, device=device, dtype=torch.int32),
)

for h in range(nheads):
N = seqlen_q
max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1))))
n = random.randint(1, max_segments)
n = min(n, N)
cuts = sorted(random.sample(range(1, N), n - 1))
lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))]

doc_ids = []
for i, length in enumerate(lengths):
doc_ids += [i for _ in range(length)]

doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device)
doc_ids_tensor[b, h, :] = base_doc_ids
return doc_ids_tensor


Expand Down
15 changes: 10 additions & 5 deletions tests/cute/test_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,15 @@ def _run_mask_test(
# Determine nheads_kv based on mode
if kv_mode == "mha":
nheads_kv = nheads
pack_gqa = False
elif kv_mode == "gqa":
nheads_kv = nheads // 2
if COMPUTE_CAPABILITY != 10:
pytest.skip("pack_gqa requires SM100")
nheads_kv = nheads // 4
pack_gqa = True
elif kv_mode == "mqa":
nheads_kv = 1
pack_gqa = False
else:
raise ValueError(f"Unknown kv_mode: {kv_mode}")

Expand Down Expand Up @@ -211,10 +216,11 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
else:
sparse_tile_m = tile_m

block_mask_nheads = 1 if pack_gqa else nheads
bm = create_block_mask(
mask_mod_flex,
batch_size,
nheads,
block_mask_nheads,
seqlen_q,
seqlen_k,
device="cuda",
Expand Down Expand Up @@ -270,8 +276,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
learnable_sink=None,
m_block_size=tile_m,
n_block_size=tile_n,
num_threads=384,
pack_gqa=False,
pack_gqa=pack_gqa,
_compute_capability=None,
score_mod=None,
mask_mod=mask_mod_cute,
Expand Down Expand Up @@ -626,7 +631,7 @@ def test_static_masks(

@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_SMOKE)
@pytest.mark.parametrize("nheads", [16])
@pytest.mark.parametrize("kv_mode", ["mha"])
@pytest.mark.parametrize("kv_mode", ["mha", "gqa"])
@pytest.mark.parametrize("headdim", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_block_sparsity", [True, False])
Expand Down