diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 11db2dab563..81462e50afd 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1010,6 +1010,16 @@ class SharedStorage: min_blocks_per_mp=1, ) + def _generate_attention_mask_cls(self, window_size_left, window_size_right): + return partial( + AttentionMask, + self.tile_m, + self.tile_n * self.cta_group_size, + swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + @cute.kernel def kernel( self, @@ -1413,13 +1423,8 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - AttentionMaskCls = partial( - AttentionMask, - self.tile_m, - self.tile_n * self.cta_group_size, - swap_AB=True, - window_size_left=window_size_left, - window_size_right=window_size_right, + AttentionMaskCls = self._generate_attention_mask_cls( + window_size_left, window_size_right ) # EMPTY # (15) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 55a92f690bd..dc99022c2a5 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -772,6 +772,18 @@ class SharedStorage: min_blocks_per_mp=1, ) + def _generate_attention_mask_cls(self, window_size_left, window_size_right): + return partial( + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ), + ) + # GPU device kernel @cute.kernel def kernel( @@ -1060,13 +1072,8 @@ def kernel( blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None ), ) - AttentionMaskCls = partial( - AttentionMask, - self.m_block_size, - self.n_block_size, - window_size_left=window_size_left, - window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + AttentionMaskCls = self._generate_attention_mask_cls( + window_size_left, window_size_right ) # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)