diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py index c07e3e94176..ecda0e273ad 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py @@ -21,6 +21,7 @@ from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, ) +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned def _as_bshkrd_tensor( @@ -251,6 +252,8 @@ def __call__( else: b = Q.shape[0] + Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)] + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) K = _as_bshkrd_tensor(K, h_k, 1, varlen) V = _as_bshkrd_tensor(V, h_k, 1, varlen) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 6c9db87458f..885ae336f5f 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -32,6 +32,7 @@ Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, ) +import flash_attn.cute.copy_utils as fa_copy_utils LAYOUT_RANK_CONSTANT = 3 @@ -2811,13 +2812,11 @@ def epilogue_clear( dK.iterator + mdK_offset, cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), ) - gdK = cute.local_tile( - mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) - ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] cdK = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) @@ -2825,24 +2824,41 @@ def epilogue_clear( dV.iterator + mdV_offset, cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), ) - gdV = cute.local_tile( - mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) - ) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] cdV = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) - for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): - if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): - gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) - gdK_i.fill(0) + num_zero_epi_threads = 256 + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + dK.element_type, self.cta_tiler[2], num_zero_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + tRG_gdK = thr_copy_r2g.partition_D(gdK) + tRG_cdK = thr_copy_r2g.partition_D(cdK) + tRG_gdV = thr_copy_r2g.partition_D(gdV) + tRG_cdV = thr_copy_r2g.partition_D(cdV) + + zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None]) + zero_frg.fill(dK.element_type(0.0)) + + # check we don't need zero fragment duplication + V_frg_size = cute.size(tRG_gdV[None, 0, None]) + assert cute.size(zero_frg) == V_frg_size + + if tidx < num_zero_epi_threads: + for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None]) - for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): - if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): - gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) - gdV_i.fill(0) + for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None]) @cute.jit def epilogue( diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index a95c9677b65..25d6a91de70 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -29,6 +29,7 @@ Sm100FusedMask as FusedMask, ) from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +import flash_attn.cute.copy_utils as fa_copy_utils class BlackwellFusedMultiHeadAttentionBackwardDQKernel: @@ -924,36 +925,45 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) - block_offset = ( - Int32(0), - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + if has_work: block_offset = ( cuseqlen_q, cuseqlen_k, Int32(0), ((Int32(0), Int32(0)), Int32(0)), ) - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.qk_mma_tiler[0], - mma_block_coord[0], - seqlen_q, - ) - if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) mdO_qdl_ = cute.domain_offset( @@ -1057,18 +1067,6 @@ def kernel( # ((atom_v, rest_v), RestN, RestK) tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) - ) # LSE lse_handle = load_lse_producer.acquire_and_advance() # 32 threads loading 128 values of 32b each @@ -1197,6 +1195,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1204,41 +1205,37 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] batch_coord = curr_block_coord[2][1] + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 + if has_work: # dq_handle = mma_dq_producer.acquire_and_advance() load_q_releaser = load_q_consumer.clone() load_do_releaser = load_do_consumer.clone() @@ -1836,33 +1833,35 @@ def kernel( curr_block_coord[2], ) batch_coord = curr_block_coord[2][1] - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + is_valid_k = trip_count > 0 + has_work = is_valid_q and is_valid_k - start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if has_work: end_count = start_count + trip_count if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( @@ -1932,6 +1931,7 @@ def kernel( ) lse_handle.release() sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() ds_mma_producer.tail() @@ -1952,61 +1952,75 @@ def kernel( # cute.printf("batch_coord={}", batch_coord) seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] - continue_cond = False cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = (cuseqlen_q,) + (None,) * 2 + mdQ_qdl_eff = cute.domain_offset(block_offset_dQ, mdQ_qdl) - mdQ_qdl_eff = mdQ_qdl - if cutlass.const_expr(cum_seqlen_q is not None): - block_offset_dQ = ( - cuseqlen_q, - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) - mdQ_qdl_eff = cute.domain_offset( - cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl - ) + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) - # (bM, bN, loopM, loopN, loopL) - gdQ_qdl = cute.flat_divide( - mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - cdQ_qdl = cute.flat_divide( - cute.make_identity_tensor(mdQ_qdl_eff.shape), - cute.select(self.dsk_block_tiler, mode=[0, 1]), - ) + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged - gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - gdQ_tma_staged = gdQ_staged - if cutlass.const_expr(not varlen): - gdQ_tma_qdl = cute.flat_divide( - mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - gdQ_tma_staged = gdQ_tma_qdl[ - None, None, curr_block_coord[0], None, curr_block_coord[2] - ] + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + if has_work: # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( - (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + seqlen_q, (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) + else: + self.dQ_epilogue_write_zero( + seqlen_q, + gdQ_staged, + cdQ_staged, + ) + work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2181,12 +2195,11 @@ def compute_step( @cute.jit def dQ_epilogue( self, - value_args: Tuple, + seqlen_q: int, dq_args: Tuple, epi_tile: cute.Tile, tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: - seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() @@ -2274,3 +2287,31 @@ def dQ_epilogue( cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer + + @cute.jit + def dQ_epilogue_write_zero( + self, + seqlen_q, + gdQ_staged, + cdQ_staged, + ): + num_epi_threads = self.threads_per_warp * len(self.epilogue_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_epi_threads + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + self.dq_dtype, cute.size(gdQ_staged.shape[1]), num_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + tdQgdQ_staged = thr_copy_r2g.partition_D(gdQ_staged) + tdQcdQ_staged = thr_copy_r2g.partition_D(cdQ_staged) + + tdQrdQ = cute.make_rmem_tensor_like(tdQgdQ_staged[None, 0, None, 0]) + tdQrdQ.fill(self.dq_dtype(0.0)) + + for iter in cutlass.range(self.iterations_dsk, unroll_full=True): + tdQgdQ = tdQgdQ_staged[None, None, None, iter] + tdQcdQ = tdQcdQ_staged[None, None, None, iter] + for m in cutlass.range(cute.size(tdQgdQ.shape[1]), unroll_full=True): + if cute.elem_less(tdQcdQ[0, m, 0][0], seqlen_q): + cute.copy(tiled_copy_r2g, tdQrdQ, tdQgdQ[None, m, None]) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index 28087125f47..379cebc1905 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -1030,6 +1030,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1071,10 +1074,6 @@ def kernel( ) seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 load_q_releaser = load_q_consumer.clone() pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) if seqlen_kv_loop_steps > 1: @@ -1323,6 +1322,11 @@ def kernel( window_size_right, ) end_count = start_count + trip_count + # require at least one softmax iteration for zero trip_count case; + # rely on masking this iteration for correctness + if end_count <= start_count: + start_count = 0 + end_count = 1 if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( FusedMask.get_trip_mask_bounds_via_block_info( @@ -1349,6 +1353,7 @@ def kernel( need_apply_mask = ( step >= n_block_min_causal_local_mask or step < n_block_min_before_local_mask + or step == end_count - 1 ) else: # Residual path only needs seqlen masking on the last K tile. @@ -1797,7 +1802,8 @@ def correction_epilog( row_sum = sSum[thread_idx] cute.arch.fence_view_async_shared() sum_handle.release() - scale = scale_output / row_sum + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = scale_output / row_sum if not row_sum_is_zero_or_nan else 0.0 o_handle = mma_o_consumer.wait_and_advance() for iter in cutlass.range(self.iterations_pv): gO = gO_staged[None, None, iter] @@ -1855,6 +1861,7 @@ def store_sum_max( sSum[thread_idx] = row_sum cute.arch.fence_view_async_shared() sum_handle.commit() + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum if cutlass.const_expr(mLSE is not None): q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx @@ -1863,7 +1870,11 @@ def store_sum_max( if cutlass.const_expr(cum_seqlen_q is not None) else current_block_coord[2] ) - lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + lse_value = ( + scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if not row_sum_is_zero_or_nan + else -Float32.inf + ) if cute.elem_less(q_idx, seqlen_q): global_q_idx = ( q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 2ebf338598c..764d7123681 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -155,8 +155,6 @@ def test_flash_attn_output( pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") if deterministic: pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") - if causal and seqlen_q > seqlen_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") device = "cuda" # set seed seed = 0 @@ -551,10 +549,6 @@ def test_flash_attn_varlen_output( pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") if deterministic: pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") - if causal and seqlen_q > seqlen_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") - if zero_lengths_q or zero_lengths_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support zero-length sequences yet") if not unpad_q or not unpad_kv: pytest.skip("SM100 head_dim=256 2CTA kernel does not support seqused_q/seqused_k mode yet (requires unpad_q=True and unpad_kv=True)") if (