Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,16 @@ class SharedStorage:
min_blocks_per_mp=1,
)

def _generate_attention_mask_cls(self, window_size_left, window_size_right):
return partial(
AttentionMask,
self.tile_m,
self.tile_n * self.cta_group_size,
swap_AB=True,
window_size_left=window_size_left,
window_size_right=window_size_right,
)

@cute.kernel
def kernel(
self,
Expand Down Expand Up @@ -1413,13 +1423,8 @@ def kernel(
)
TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)

AttentionMaskCls = partial(
AttentionMask,
self.tile_m,
self.tile_n * self.cta_group_size,
swap_AB=True,
window_size_left=window_size_left,
window_size_right=window_size_right,
AttentionMaskCls = self._generate_attention_mask_cls(
window_size_left, window_size_right
)
# EMPTY
# (15)
Expand Down
21 changes: 14 additions & 7 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,18 @@ class SharedStorage:
min_blocks_per_mp=1,
)

def _generate_attention_mask_cls(self, window_size_left, window_size_right):
return partial(
AttentionMask,
self.m_block_size,
self.n_block_size,
window_size_left=window_size_left,
window_size_right=window_size_right,
qhead_per_kvhead_packgqa=(
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
),
)

# GPU device kernel
@cute.kernel
def kernel(
Expand Down Expand Up @@ -1060,13 +1072,8 @@ def kernel(
blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None
),
)
AttentionMaskCls = partial(
AttentionMask,
self.m_block_size,
self.n_block_size,
window_size_left=window_size_left,
window_size_right=window_size_right,
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
AttentionMaskCls = self._generate_attention_mask_cls(
window_size_left, window_size_right
)
# Cluster wait before tensor memory alloc
pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)
Expand Down