diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5c63513faf8..0e60f835330 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,31 @@ repos:
     hooks:
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]
-        files: ^flash_attn/cute/(flash_bwd_sm90|flash_bwd_preprocess|flash_bwd_postprocess|softmax)\.py$
+        files: ^flash_attn/cute/.*\.py$
+        exclude: &cute_exclude |
+          (?x)^flash_attn/cute/(
+            __init__|
+            blackwell_helpers|
+            block_info|
+            copy_utils|
+            cute_dsl_utils|
+            fast_math|
+            flash_bwd|
+            flash_fwd|
+            flash_fwd_combine|
+            flash_fwd_sm100|
+            hopper_helpers|
+            interface|
+            mask|
+            mma_sm100_desc|
+            named_barrier|
+            pack_gqa|
+            pipeline|
+            seqlen_info|
+            testing|
+            tile_scheduler|
+            utils
+          )\.py$
       - id: ruff-format
-        files: ^flash_attn/cute/(flash_bwd_sm90|flash_bwd_preprocess|flash_bwd_postprocess|softmax)\.py$
+        files: ^flash_attn/cute/.*\.py$
+        exclude: *cute_exclude
diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py
index 839f407f75c..e3072d8ce85 100644
--- a/flash_attn/cute/ampere_helpers.py
+++ b/flash_attn/cute/ampere_helpers.py
@@ -8,11 +8,14 @@
 def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
     dtype_byte = cutlass.const_expr(dtype.width // 8)
     bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
-    smem_k_block_size = cutlass.const_expr(
-        128
-        if bytes_per_row % 128 == 0
-        else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
-    ) // dtype_byte
+    smem_k_block_size = (
+        cutlass.const_expr(
+            128
+            if bytes_per_row % 128 == 0
+            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
+        )
+        // dtype_byte
+    )
     swizzle_bits = (
         4
         if smem_k_block_size == 128
@@ -22,7 +25,9 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.Compo
     return cute.make_composed_layout(
         cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
         0,
-        cute.make_ordered_layout((8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)),
+        cute.make_ordered_layout(
+            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
+        ),
     )
 
 