diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 8211e01965e..763e824e55b 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -385,7 +385,7 @@ def __call__( for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) # Assume all strides are divisible by 128 bits except the last stride # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints) - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if s != 0 else s for s in t.stride[:-1]), t.stride[-1]) + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1]), t.stride[-1]) mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)] self.varlen_q = (mCuSeqlensQ is not None) self._setup_attributes() @@ -401,7 +401,7 @@ def __call__( TileScheduler = SingleTileScheduler num_batch = mK.shape[0] - # Uses seqlen k, etc. since main bwd kernel's blocks are over n + # Uses seqlen k, etc. since main bwd kernel's blocks are over n tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(mK.shape[1], self.n_block_size), num_head=num_head, @@ -416,7 +416,7 @@ def __call__( mCuSeqlensQ=mCuSeqlensK, mSeqUsedQ=mSeqUsedK, ) - + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) @@ -1000,7 +1000,7 @@ def epilogue( num_head: cutlass.Int32, batch_size: cutlass.Int32, seqlen: SeqlenInfoQK, - d_head: cutlass.Int32, + d_head: cutlass.Int32, d_head_v: cutlass.Int32 ): rdV = cute.make_fragment_like(acc_dV, self.dtype) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index ede18638a73..377a66a4385 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -354,7 +354,9 @@ def __call__( # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints) new_stride = lambda t: ( *( - cute.assume(s, divby=128 // t.element_type.width) if s != 0 else s + cute.assume(s, divby=128 // t.element_type.width) + if not isinstance(s, int) or s != 0 + else s for s in t.stride[:-1] ), t.stride[-1], diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 3ba52ce4540..c13cd267719 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -663,7 +663,7 @@ def __call__( new_stride = lambda t: ( *( cute.assume(s, divby=128 // t.element_type.width) - if s != 0 + if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1] ), @@ -1306,7 +1306,7 @@ def __call__( new_stride = lambda t: ( *( cute.assume(s, divby=128 // t.element_type.width) - if s != 0 + if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1] ), @@ -2482,4 +2482,3 @@ def warp_scheduler_barrier_arrive(self): barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, ) -