diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 619e0408cd4..a6d061b19b5 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -11,7 +11,7 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp -import cutlass.utils.ampere_helpers as sm80_utils_basic +import cutlass.utils as utils_basic from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import utils @@ -125,7 +125,7 @@ def can_implement( smem_usage_V = n_block_size * head_dim_v * 2 smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K - smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False return True diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index d1b307acf02..b70da9a5264 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -16,7 +16,7 @@ import cutlass.cute as cute from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup -import cutlass.utils.ampere_helpers as sm80_utils_basic +import cutlass.utils as utils_basic import cutlass.utils.hopper_helpers as sm90_utils_basic from flash_attn.cute import ampere_helpers as sm80_utils @@ -127,7 +127,7 @@ def can_implement( smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 - smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False # Check if twice the block size is divisible by the number of threads