Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion benchmarks/benchmark_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@

# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.cute.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
except ImportError:
flash_attn_func = None
flash_attn_varlen_func = None
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python
try:
Expand Down
78 changes: 1 addition & 77 deletions flash_attn/cute/fast_math.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
# Copyright (c) 2025, Tri Dao.

from typing import Tuple

import cutlass
import cutlass.cute as cute
from cutlass import Int32, Uint32
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import llvm
from cutlass import Int32


@cute.jit
Expand All @@ -23,75 +19,3 @@ def clz(x: Int32) -> Int32:
res = Int32(i)
done = True
return res


def find_log2(x: Int32) -> Int32:
    """Return ceil(log2(x)) for a positive 32-bit integer x."""
    # 31 - clz(x) gives the index of the highest set bit, i.e. floor(log2(x)).
    floor_log2: Int32 = Int32(31 - clz(x))
    # x & (x - 1) clears the lowest set bit; it is zero iff x is a power of two.
    # For non-powers-of-two we round the logarithm up by one.
    needs_round_up = (x & (x - 1)) != 0
    return floor_log2 + needs_round_up


@dsl_user_op
def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
    """Return the high 32 bits of the full 64-bit product a * b (unsigned).

    Lowered to the PTX instruction ``mul.hi.u32`` via an LLVM inline-asm op,
    mirroring CUDA's ``__umulhi`` intrinsic. Used by fast integer division
    (multiply-by-reciprocal) code paths.

    Parameters:
        a, b: 32-bit operands; reinterpreted as unsigned by the instruction.
        loc, ip: MLIR source-location / insertion-point plumbing forwarded to
            the IR builders (supplied by the ``dsl_user_op`` machinery).
    """
    return Uint32(
        llvm.inline_asm(
            # Result type: a single i32 register.
            T.i32(),
            # Operands must be materialized as IR values at the current loc/ip.
            [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
            "mul.hi.u32 $0, $1, $2;",
            # Constraint string: one output register, two input registers.
            "=r,r,r",
            has_side_effects=False,  # pure arithmetic — safe to CSE/reorder
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


class FastDivmod:
    """Fast integer division/modulo by a fixed divisor on device.

    Precomputes (on the host, via ``create``) a magic multiplier and shift so
    that device-side division becomes a ``umulhi`` plus a shift instead of a
    hardware divide — the classic divide-by-invariant-integer technique.

    The ``__extract_mlir_values__`` / ``__new_from_mlir_values__`` hooks let
    the cutlass DSL flatten this object into its constituent MLIR values when
    crossing the host/kernel boundary and rebuild it on the other side.
    """

    def __init__(
        self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None
    ):
        # NOTE(review): "multipler" is a typo but is part of the keyword
        # interface of __init__; renaming it would break keyword callers.
        self.divisor = divisor
        self.multiplier = multipler
        self.shift_right = shift_right
        # Kept so __new_from_mlir_values__ can rebuild with the same location.
        self._loc = loc

    # called by host
    @staticmethod
    def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod":
        """Construct the FastDivmod object, in host code.
        This precomputes some values based on the divisor and is computationally expensive.
        """
        # p in [32, 63]: ceil(log2(divisor)) extra bits of precision above 31.
        p = Uint32(31 + find_log2(divisor))
        divisor_u32 = Uint32(divisor)
        # multiplier = ceil(2^p / divisor), computed in 64-bit to avoid overflow.
        multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
        shift_right = Uint32(p - 32)
        return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip)

    @cute.jit
    def div(self, dividend: Int32) -> Int32:
        """Return dividend // divisor using the precomputed magic numbers."""
        # Special-case divisor == 1: the magic-number path would be wasted work.
        return (
            Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
            if self.divisor != 1
            else dividend
        )

    def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
        """Return (dividend // divisor, dividend % divisor)."""
        quotient = self.div(dividend)
        # Remainder via multiply-and-subtract — cheaper than a second divide.
        remainder = dividend - quotient * self.divisor
        return quotient, remainder

    def __extract_mlir_values__(self):
        # Flatten the three fields into a list of MLIR values; remember how
        # many values each field contributed so reconstruction can re-split.
        values, self._values_pos = [], []
        for obj in [self.divisor, self.multiplier, self.shift_right]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        # Inverse of __extract_mlir_values__: consume values in the same field
        # order and counts recorded in self._values_pos.
        # NOTE(review): the original `ip` is not preserved here, only `loc`.
        obj_list = []
        for obj, n_items in zip(
            [self.divisor, self.multiplier, self.shift_right], self._values_pos
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return FastDivmod(*(tuple(obj_list)), loc=self._loc)
24 changes: 16 additions & 8 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ def __call__(
assert mdK_semaphore is not None
assert mdV_semaphore is not None
mdK_semaphore, mdV_semaphore = [
utils.select(t, mode=semaphore_transpose)
for t in (mdK_semaphore, mdV_semaphore)
utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore)
]
else:
mdK_semaphore = None
Expand Down Expand Up @@ -562,7 +561,7 @@ def __call__(
cute.size(mQ.shape[2]), # num_heads = num_query_heads
cute.size(mK.shape[3]),
1, # num_splits
cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k
cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is all ruff formatting

mQ.shape[1],
mV.shape[1],
total_q=cute.size(mQ.shape[0]),
Expand Down Expand Up @@ -1905,7 +1904,9 @@ def compute_loop(

if const_expr(not self.use_smem_dS_for_mma_dK):
cute.arch.fence_view_async_tmem_store()
cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
self.compute_sync_barrier.arrive_and_wait()

# with cute.arch.elect_one():
Expand Down Expand Up @@ -2032,7 +2033,7 @@ def dQacc_reduce(
gdQaccum = cute.flat_divide(
gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,)
)

if const_expr(self.deterministic):
mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]

Expand Down Expand Up @@ -2068,12 +2069,17 @@ def dQacc_reduce(
if const_expr(self.spt):
n_block_max_for_m_block = min(
n_block_global_max,
cute.ceil_div((m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, self.tile_n)
cute.ceil_div(
(m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q,
self.tile_n,
),
)
lock_value = n_block_max_for_m_block - 1 - n_block
else:
lock_value = n_block
barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value)
barrier.wait_eq(
mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value
)
self.reduce_sync_barrier.arrive_and_wait()
# Copy from shared memory to global memory
if is_tma_warp:
Expand Down Expand Up @@ -2101,7 +2107,9 @@ def dQacc_reduce(
# semaphore release for prior m_block
if const_expr(self.deterministic and stage == 0 and delay_semaphore_release):
if m_block > m_block_min:
barrier.arrive_inc(mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1)
barrier.arrive_inc(
mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1
)

# semaphore release
# NOTE: arrive_inc calls red_release which issues membar
Expand Down
10 changes: 5 additions & 5 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
SingleTileVarlenScheduler,
ParamsBase,
)
from flash_attn.cute.fast_math import FastDivmod
from cutlass.cute import FastDivmodDivisor


class FlashAttentionForwardBase:
Expand Down Expand Up @@ -692,8 +692,8 @@ def __call__(
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_q_divmod = FastDivmod.create(seqlen_q)
seqlen_k_divmod = FastDivmod.create(seqlen_k)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

self.kernel(
Expand Down Expand Up @@ -1503,8 +1503,8 @@ def __call__(
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_q_divmod = FastDivmod.create(seqlen_q)
seqlen_k_divmod = FastDivmod.create(seqlen_k)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

self.kernel(
Expand Down
20 changes: 10 additions & 10 deletions flash_attn/cute/flash_fwd_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from cutlass import Float32, Int32, const_expr

from flash_attn.cute import utils
from flash_attn.cute.fast_math import FastDivmod
from flash_attn.cute.seqlen_info import SeqlenInfo
from cutlass.cute import FastDivmodDivisor


class FlashAttentionForwardCombine:
Expand Down Expand Up @@ -257,9 +257,9 @@ class SharedStorage:
num_head = mO_partial.shape[3]
batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)

# Create FastDivmod objects for efficient division
seqlen_divmod = FastDivmod.create(seqlen)
head_divmod = FastDivmod.create(num_head)
# Create FastDivmodDivisor objects for efficient division
seqlen_divmod = FastDivmodDivisor(seqlen)
head_divmod = FastDivmodDivisor(num_head)

grid_dim = (
cute.ceil_div(seqlen * num_head, self.m_block_size),
Expand Down Expand Up @@ -311,8 +311,8 @@ def kernel(
gmem_tiled_copy_O: cute.TiledCopy,
gmem_tiled_copy_LSE: cute.TiledCopy,
s2r_tiled_copy_LSE: cute.TiledCopy,
seqlen_divmod: FastDivmod,
head_divmod: FastDivmod,
seqlen_divmod: FastDivmodDivisor,
head_divmod: FastDivmodDivisor,
varlen: cutlass.Constexpr[bool],
):
# Thread and block indices
Expand Down Expand Up @@ -380,9 +380,9 @@ def kernel(
mi = tLSEcLSE[0, 0, m][1] # Get m coordinate
idx = m_block * self.m_block_size + mi
if idx < max_idx:
# Calculate actual sequence position and head using FastDivmod
# Calculate actual sequence position and head using FastDivmodDivisor
if const_expr(not varlen):
head_idx, m_idx = seqlen_divmod.divmod(idx)
head_idx, m_idx = divmod(idx, seqlen_divmod)
else:
head_idx = idx // seqlen
m_idx = idx - head_idx * seqlen
Expand Down Expand Up @@ -420,7 +420,7 @@ def kernel(
mi = tOcO[0, m, 0][0] # m coordinate
idx = m_block * self.m_block_size + mi
if const_expr(not varlen):
tOhidx[m], tOmidx[m] = seqlen_divmod.divmod(idx)
tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
else:
tOhidx[m] = idx // seqlen
tOmidx[m] = idx - tOhidx[m] * seqlen
Expand Down Expand Up @@ -536,7 +536,7 @@ def kernel(
idx = m_block * self.m_block_size + mi
if idx < max_idx:
if const_expr(not varlen):
head_idx, m_idx = seqlen_divmod.divmod(idx)
head_idx, m_idx = divmod(idx, seqlen_divmod)
else:
head_idx = idx // seqlen
m_idx = idx - head_idx * seqlen
Expand Down
10 changes: 5 additions & 5 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from flash_attn.cute.pack_gqa import PackGQA
from flash_attn.cute import mma_sm100_desc as sm100_desc
from flash_attn.cute import blackwell_helpers as sm100_utils
from flash_attn.cute.fast_math import FastDivmod
from cutlass.cute import FastDivmodDivisor
from flash_attn.cute.tile_scheduler import (
TileSchedulerArguments,
SingleTileScheduler,
Expand Down Expand Up @@ -659,8 +659,8 @@ class SharedStorage:
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_q_divmod = FastDivmod.create(seqlen_q)
seqlen_k_divmod = FastDivmod.create(seqlen_k)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
Expand Down Expand Up @@ -1190,7 +1190,7 @@ def load(
mPageTable,
mK,
mV,
FastDivmod.create(page_size),
FastDivmodDivisor(page_size),
batch_idx,
head_idx_kv,
tidx,
Expand Down Expand Up @@ -2660,7 +2660,7 @@ def apply_score_mod(

if cutlass.const_expr(aux_tensors is not None):
seqlen_q_divmod, _ = fastdiv_mods
_, q_idx_logical = seqlen_q_divmod.divmod(q_idx_logical)
_, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod)

apply_score_mod_inner(
tSrS_t2r,
Expand Down
8 changes: 4 additions & 4 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def apply_mask(
global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
row_for_mod = global_row_idx
if const_expr(wrap_aux_indices):
_, row_for_mod = fastdiv_mods[0].divmod(global_row_idx)
_, row_for_mod = divmod(global_row_idx, fastdiv_mods[0])

for col in cutlass.range_constexpr(ncol):
col_idx_local = t0ScS_mn[0, col][1]
# Convert to absolute column index
global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
col_for_mod = global_col_idx
if const_expr(wrap_aux_indices):
_, col_for_mod = fastdiv_mods[1].divmod(global_col_idx)
_, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])

batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
Expand Down Expand Up @@ -357,7 +357,7 @@ def apply_mask_sm100(
mask_row = global_row
mask_row_for_mod = mask_row
if const_expr(wrap_aux_indices):
_, mask_row_for_mod = fastdiv_mods[0].divmod(mask_row)
_, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)

ncol = const_expr(cute.size(tScS_t2r.shape))
Expand All @@ -366,7 +366,7 @@ def apply_mask_sm100(
global_col = col_coord + n_block * self.tile_n
global_col_for_mod = global_col
if const_expr(wrap_aux_indices):
_, global_col_for_mod = fastdiv_mods[1].divmod(global_col)
_, global_col_for_mod = divmod(global_col, fastdiv_mods[1])
kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
mask_value = mask_mod(
batch_idx_ssa,
Expand Down
22 changes: 17 additions & 5 deletions flash_attn/cute/paged_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from cutlass import Int32, const_expr

from flash_attn.cute import utils
from flash_attn.cute.fast_math import FastDivmod
from flash_attn.cute.cute_dsl_utils import ParamsBase
from cutlass.cute import FastDivmodDivisor


@dataclass
Expand All @@ -18,7 +18,7 @@ class PagedKVManager(ParamsBase):
mV_paged: cute.Tensor
thread_idx: Int32

page_size_divmod: FastDivmod
page_size_divmod: FastDivmodDivisor
seqlen_k: Int32
leftpad_k: Int32
n_block_size: Int32
Expand All @@ -42,7 +42,7 @@ def create(
mPageTable: cute.Tensor,
mK_paged: cute.Tensor,
mV_paged: cute.Tensor,
page_size_divmod: FastDivmod,
page_size_divmod: FastDivmodDivisor,
bidb: Int32,
bidh: Int32,
thread_idx: Int32,
Expand Down Expand Up @@ -118,7 +118,7 @@ def load_page_table(self, n_block: Int32):
row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row
row_idx = n_block * self.n_block_size + row

page_idx, page_offset = self.page_size_divmod.divmod(row_idx + self.leftpad_k)
page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)

is_valid = (
(i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size
Expand Down Expand Up @@ -173,4 +173,16 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
)
elif const_expr(K_or_V == "V"):
# Don't need to clear out the rest of the smem for K since we'll mask out the scores anyway.
tXsX[None, m, None].fill(0)
fill_swizzled(tXsX[None, m, None], 0)


@cutlass.dsl_user_op
def fill_swizzled(tensor, value: cutlass.Numeric, *, loc=None, ip=None) -> None:
    """Fill tensor with a constant value.

    Fills all elements of the tensor with the specified value, assuming static size
    and supported memory space.

    Implementation: stages the constant in a register-memory tensor with the
    same layout/element type, then vector-copies it into the destination —
    presumably because a direct ``.fill`` is not supported on swizzled
    (shared-memory) layouts; confirm against the cute tensor API.

    Parameters:
        tensor: destination tensor to fill (statically sized).
        value: constant to write into every element.
        loc, ip: MLIR location/insertion-point plumbing consumed by the
            ``dsl_user_op`` decorator (unused in this body).
    """
    rTmp = cute.make_rmem_tensor_like(tensor, tensor.element_type)
    rTmp.fill(value)
    cute.autovec_copy(rTmp, tensor)
Loading