Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_attn/cute/flash_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, warp
-import cutlass.utils.ampere_helpers as sm80_utils_basic
+import cutlass.utils as utils_basic

from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute import utils
Expand Down Expand Up @@ -125,7 +125,7 @@ def can_implement(
smem_usage_V = n_block_size * head_dim_v * 2
smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
-smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"]
+smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
if smem_usage > smem_capacity:
return False
return True
Expand Down
4 changes: 2 additions & 2 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
-import cutlass.utils.ampere_helpers as sm80_utils_basic
+import cutlass.utils as utils_basic
import cutlass.utils.hopper_helpers as sm90_utils_basic

from flash_attn.cute import ampere_helpers as sm80_utils
Expand Down Expand Up @@ -127,7 +127,7 @@ def can_implement(
smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)
smem_usage = smem_usage_QV + smem_usage_K
# TODO: sm86 and sm89
-smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"]
+smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
if smem_usage > smem_capacity:
return False
# Check if twice the block size is divisible by the number of threads
Expand Down