Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flash_attn/cute/flash_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def __call__(
for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
# Assume all strides are divisible by 128 bits except the last stride
# Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if s != 0 else s for s in t.stride[:-1]), t.stride[-1])
new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1]), t.stride[-1])
mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)]
self.varlen_q = (mCuSeqlensQ is not None)
self._setup_attributes()
Expand All @@ -401,7 +401,7 @@ def __call__(
TileScheduler = SingleTileScheduler
num_batch = mK.shape[0]

# Uses seqlen k, etc. since main bwd kernel's blocks are over n
# Uses seqlen k, etc. since main bwd kernel's blocks are over n
tile_sched_args = TileSchedulerArguments(
num_block=cute.ceil_div(mK.shape[1], self.n_block_size),
num_head=num_head,
Expand All @@ -416,7 +416,7 @@ def __call__(
mCuSeqlensQ=mCuSeqlensK,
mSeqUsedQ=mSeqUsedK,
)

tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

Expand Down Expand Up @@ -1000,7 +1000,7 @@ def epilogue(
num_head: cutlass.Int32,
batch_size: cutlass.Int32,
seqlen: SeqlenInfoQK,
d_head: cutlass.Int32,
d_head: cutlass.Int32,
d_head_v: cutlass.Int32
):
rdV = cute.make_fragment_like(acc_dV, self.dtype)
Expand Down
4 changes: 3 additions & 1 deletion flash_attn/cute/flash_bwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,9 @@ def __call__(
# Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
new_stride = lambda t: (
*(
cute.assume(s, divby=128 // t.element_type.width) if s != 0 else s
cute.assume(s, divby=128 // t.element_type.width)
if not isinstance(s, int) or s != 0
else s
for s in t.stride[:-1]
),
t.stride[-1],
Expand Down
5 changes: 2 additions & 3 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def __call__(
new_stride = lambda t: (
*(
cute.assume(s, divby=128 // t.element_type.width)
if s != 0
if not isinstance(s, int) or s != 0
else s
for s in t.stride[:-1]
),
Expand Down Expand Up @@ -1306,7 +1306,7 @@ def __call__(
new_stride = lambda t: (
*(
cute.assume(s, divby=128 // t.element_type.width)
if s != 0
if not isinstance(s, int) or s != 0
else s
for s in t.stride[:-1]
),
Expand Down Expand Up @@ -2482,4 +2482,3 @@ def warp_scheduler_barrier_arrive(self):
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
number_of_threads=2 * self.num_threads_per_warp_group,
)