Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit. Hold shift + click to select a range.
f6880af
clean up and rebase for PR
reubenconducts Oct 16, 2025
e7121a3
add mask mod tests
reubenconducts Oct 16, 2025
864de11
add benchmarking files
reubenconducts Oct 16, 2025
d300e33
refactor for better style
reubenconducts Oct 16, 2025
9fbc2d4
remove extraneous csrc
reubenconducts Oct 19, 2025
b81eaa4
type hint buffers
reubenconducts Oct 20, 2025
e05ec82
Merge remote-tracking branch 'upstream/main' into rstern/flex-mask-mod
reubenconducts Oct 20, 2025
5d5bb09
refactor: order of non/overlap and modify blocksparse producer to agr…
reubenconducts Oct 21, 2025
a17bb58
change variable name back to buffers
reubenconducts Oct 21, 2025
7c563ac
remove unnecessary variable in first_half_block
reubenconducts Oct 21, 2025
b5f7082
restore erroneous packgqa deletion
reubenconducts Oct 21, 2025
ab5c024
add blocksparsity and mask_mod asserts to interface.py
reubenconducts Oct 21, 2025
06820e8
fix rebase issues
reubenconducts Oct 21, 2025
db0ea95
Restore submodule and reset pointer to upstream/main
reubenconducts Oct 21, 2025
41ba160
rename cutlass.const_expr to const_expr
reubenconducts Oct 21, 2025
c6e0d6b
support fully masked m blocks (i.e. skipped tiles)
reubenconducts Oct 21, 2025
d28e6a8
remove outdated commented code
reubenconducts Oct 21, 2025
ee938b0
Merge remote-tracking branch 'upstream/main' into rstern/flex-mask-mod
reubenconducts Oct 22, 2025
c1f26dd
rename buffers -> aux_tensors, fix score_mod test in sm90 fwd
reubenconducts Oct 23, 2025
bbdff38
fix mask mod interface issues and tests
reubenconducts Oct 23, 2025
5b9961a
remove newline at end of file
reubenconducts Oct 23, 2025
c889d18
format with ruff
reubenconducts Oct 24, 2025
c08a851
format mask & sm100 with ruff
reubenconducts Oct 24, 2025
f3561a6
format more files with ruff
reubenconducts Oct 24, 2025
932d3ab
format barrier.py with ruff
reubenconducts Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions flash_attn/cute/barrier.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,8 +4,9 @@
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import llvm


@dsl_user_op
def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
state = llvm.inline_asm(
T.i32(),
Expand All @@ -18,8 +19,11 @@ def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
)
return cutlass.Int32(state)


@dsl_user_op
def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None:
def red_relaxed(
lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
llvm.inline_asm(
None,
Expand All @@ -31,8 +35,11 @@ def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N
asm_dialect=llvm.AsmDialect.AD_ATT,
)


@dsl_user_op
def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None:
def red_release(
lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
llvm.inline_asm(
None,
Expand All @@ -43,28 +50,22 @@ def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)



@cute.jit
def wait_eq(
lock_ptr : cute.Pointer,
thread_idx : int | Int32,
flag_offset : int,
val : Int32
) -> None:
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
flag_ptr = lock_ptr + flag_offset
if thread_idx == 0:
read_val = Int32(0)
while read_val != val:
read_val = ld_acquire(flag_ptr)


@cute.jit
def arrive_inc(
lock_ptr : cute.Pointer,
thread_idx : int | Int32,
flag_offset : int,
val : cutlass.Constexpr[Int32]
lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
flag_ptr = lock_ptr + flag_offset
if thread_idx == 0:
red_release(flag_ptr, val)
# red_relaxed(flag_ptr, val)
# red_relaxed(flag_ptr, val)
36 changes: 17 additions & 19 deletions flash_attn/cute/benchmark_mask_mod.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,6 @@

from dataclasses import dataclass
import math
from pickle import FALSE
from typing import Any, Dict, Optional, Tuple

import cuda.bindings.driver as cuda
Expand Down Expand Up @@ -51,7 +50,7 @@ class BenchmarkConfig:
# Mask parameters
use_mask_mod: bool = True
mask_mod_name: str = "causal"
has_buffers: bool = mask_mod_name == "document"
has_aux_tensors: bool = mask_mod_name == "document"

# Sliding window parameter (used when mask_mod_name == "sliding_window")
window_size: int = 128
Expand Down Expand Up @@ -235,7 +234,6 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]:
dtype=torch.float32,
device=device,
)


tensors = {
"q": q.contiguous(),
Expand All @@ -244,10 +242,10 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]:
"out": out.contiguous(),
"lse": lse.contiguous(),
}

if config.use_learnable_sink:
learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device)

tensors["learnable_sink"] = learnable_sink.contiguous()

# Compute block sparsity when using mask_mod
Expand All @@ -256,14 +254,14 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]:
doc_id = random_doc_id_tensor(
config.batch_size, config.nheads, config.seqlen_q, device=device
)
tensors["buffers"] = [doc_id.contiguous()]
tensors["aux_tensors"] = [doc_id.contiguous()]
full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity(
config=self.config,
mask_mod_flex=self.mask_mod_flex,
device=device,
cu_seqlens_q=tensors.get("cu_seqlens_q"),
cu_seqlens_k=tensors.get("cu_seqlens_k"),
buffers=tensors.get("buffers"),
aux_tensors=tensors.get("aux_tensors"),
)

if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]):
Expand Down Expand Up @@ -329,7 +327,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
mma_pv_is_rs=config.mma_pv_is_rs,
mask_mod=self.mask_mod_cute,
Q_in_regs=False,
has_buffers=config.has_buffers,
has_aux_tensors=config.has_aux_tensors,
)

softmax_scale = 1.0 / math.sqrt(config.headdim)
Expand Down Expand Up @@ -405,14 +403,14 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
else None
)

if "buffers" in tensors:
buffers_cute = []
for i in range(len(tensors["buffers"])):
buf = from_dlpack(tensors["buffers"][i].detach(), assumed_align=4)
buffers_cute.append(buf.mark_layout_dynamic(leading_dim=2))
if "aux_tensors" in tensors:
aux_tensors_cute = []
for i in range(len(tensors["aux_tensors"])):
buf = from_dlpack(tensors["aux_tensors"][i].detach(), assumed_align=4)
aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2))

else:
buffers_cute = None
aux_tensors_cute = None

# Window parameters for is_local
window_left_cute = (
Expand Down Expand Up @@ -443,7 +441,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
full_block_idx_cute,
mask_block_cnt_cute,
mask_block_idx_cute,
buffers_cute,
aux_tensors_cute,
# None,
)

Expand All @@ -467,7 +465,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]
full_block_idx_cute,
mask_block_cnt_cute,
mask_block_idx_cute,
buffers_cute,
aux_tensors_cute,
# None,
)

Expand Down Expand Up @@ -496,7 +494,7 @@ def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float:
num_blocks = (config.seqlen_k + block_size - 1) // block_size
sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0
elif config.mask_mod_name == "document":
vals = tensors["buffers"][0]
vals = tensors["aux_tensors"][0]
val_mask = torch.ones_like(vals, dtype=torch.bool)
val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1]
total = torch.where(val_mask, vals.square(), 0).sum()
Expand Down Expand Up @@ -573,7 +571,7 @@ def benchmark(self) -> Dict[str, Any]:
torch.cuda.synchronize()

times.append(start.elapsed_time(end))

times_tensor = torch.tensor(times)
mean_time = times_tensor.mean().item()
std_time = times_tensor.std().item() if len(times) > 1 else 0.0
Expand Down Expand Up @@ -683,7 +681,7 @@ def _print_results(self, results: Dict[str, Any]):
# seqlen_k=192,
use_varlen=False,
use_mask_mod=True,
mask_mod_name="identity",
mask_mod_name="causal",
window_size=128, # Configurable window size for mask_mod
use_learnable_sink=False,
causal=False,
Expand Down
Loading