From fce987086827e9754f46363d3cb83ad6db8823d7 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 28 Jan 2026 16:46:49 +0000 Subject: [PATCH 01/10] clean up and add more vectorized tests --- flash_attn/cute/cute_dsl_utils.py | 18 +++ flash_attn/cute/flash_fwd_sm100.py | 13 +- flash_attn/cute/interface.py | 4 +- flash_attn/cute/utils.py | 12 ++ tests/cute/score_mod_definitions.py | 84 ++++++++++++ tests/cute/test_score_mod.py | 199 +++++++++++++++++++--------- 6 files changed, 261 insertions(+), 69 deletions(-) diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 9d2f7aa739b..86c05dd3262 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -152,6 +152,24 @@ def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, ena return tensor.mark_layout_dynamic(leading_dim=leading_dim) +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: """Return tuple of bools indicating which dims have stride=0 (broadcast). 
diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 886d02632a5..7acc0daed91 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -131,10 +131,15 @@ def __init__( ) self.score_mod = score_mod self.mask_mod = mask_mod - if cutlass.const_expr(has_aux_tensors): - self.vec_size: cutlass.Constexpr = 1 - else: - self.vec_size: cutlass.Constexpr = 2 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) + # if hasattr(score_mod, "__vec_size__"): + # self.vec_size: cutlass.Constexpr = score_mod.__vec_size__ + # elif cutlass.const_expr(has_aux_tensors): + # self.vec_size: cutlass.Constexpr = 1 + # else: + # self.vec_size: cutlass.Constexpr = 2 # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8d936602e31..cfeb5b48cd8 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -41,7 +41,7 @@ from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import to_cute_tensor +from flash_attn.cute.cute_dsl_utils import to_cute_tensor, to_cute_aux_tensor from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess @@ -433,7 +433,7 @@ def _flash_attn_fwd( cute_aux_tensors = None if aux_tensors is not None: - cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] + cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 
f2383e89415..251185edfcd 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -30,6 +30,12 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str: if hasattr(func, "__cute_hash__"): return func.__cute_hash__ + # __vec_size__ is attr of @cute.jitted mod + if hasattr(func, "__vec_size__"): + vec_size = func.__vec_size__ + else: + vec_size = None + # Unwrap decorated functions (e.g., cute.jit wrappers). if hasattr(func, "__wrapped__"): base_func = func.__wrapped__ @@ -37,6 +43,10 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str: return base_func.__cute_hash__ func = base_func + # if base func has __vec_size__, overwrite + if hasattr(func, "__vec_size__"): + vec_size = func.__vec_size__ + try: data = inspect.getsource(func).encode() except (OSError, TypeError): @@ -52,6 +62,8 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str: cell_value = cell.cell_contents hasher.update(repr(cell_value).encode()) + hasher.update(str(vec_size).encode()) + hash = hasher.hexdigest() if set_cute_hash: diff --git a/tests/cute/score_mod_definitions.py b/tests/cute/score_mod_definitions.py index be6333a6448..f131bf2ae09 100644 --- a/tests/cute/score_mod_definitions.py +++ b/tests/cute/score_mod_definitions.py @@ -15,12 +15,28 @@ def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t return tSrS_ssa +@cute.jit +def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa + + @cute.jit def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): mask = operator.ge(q_idx, kv_idx) return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) +@cute.jit +def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask = cute.make_rmem_tensor(kv_idx.shape, dtype=cutlass.Boolean) + kv_idx0 = kv_idx[0] + q_idx0 = q_idx[0] + for i in cutlass.range_constexpr(cute.size(mask.shape)): + 
mask[i] = q_idx0 >= kv_idx0 + i + mask_ssa = mask.load() + return cute.where(mask_ssa, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + @cute.jit def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): diff = q_idx - kv_idx @@ -28,6 +44,18 @@ def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t return tSrS_ssa + abs_diff.to(cutlass.Float32) +@cute.jit +def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + q_idx0 = q_idx[0] + kv_idx0 = kv_idx[0] + diff0 = q_idx0 - kv_idx0 + abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) + for i in cutlass.rante_constexpr(cute.size(kv_idx.shape)): + diffi = diff0 - i + abs_diff[i] = mlir_math.absi(diffi) + return tSrS_ssa + abs_diff.load().to(cutlass.Float32) + + @cute.jit def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): diff = q_idx - kv_idx @@ -36,10 +64,25 @@ def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au return tSrS_ssa + scaled.to(cutlass.Float32) +@cute.jit +def score_mod_rel_bias_x2_vectorized( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + q_idx0 = q_idx[0] + kv_idx0 = kv_idx[0] + diff0 = q_idx0 - kv_idx0 + abs_diff_x2 = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) + for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): + diffi = diff0 - i + abs_diff_x2[i] = mlir_math.absi(diffi) * 2 + return tSrS_ssa + abs_diff_x2.load().to(cutlass.Float32) + + @cute.jit def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): return tSrS_ssa * cute.full_like(tSrS_ssa, 2) +score_mod_times_two_vectorized = score_mod_times_two @cute.jit def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -53,6 +96,21 @@ def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tens abs_diff = 
cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32) return score - slope * abs_diff +@cute.jit +def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + score = tSrS_ssa.to(cutlass.Float32) + slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8) + slope = cute.math.exp2( + slope_exp.to(cutlass.Float32) + * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) + ) + diff0 = q_idx[0] - kv_idx[0] + abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype) + for i in cutlass.range_constexpr(cute.size(abs_diff.shape)): + diffi = diff0 - i + abs_diff[i] = mlir_math.absi(diffi) + return score - slope * abs_diff.load().to(cutlass.Float32) + @cute.jit def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -88,6 +146,16 @@ def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux bias_val = (bias_frag.load()).to(cutlass.Float32) return tSrS_ssa + bias_val +@cute.jit +def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + batch_bias = aux_tensors[0] + dtype = batch_bias.element_type + b_idx0 = b_idx[0] + bias_frag = cute.make_rmem_tensor(1, dtype) + bias_frag[0] = batch_bias[b_idx0] + bias_val = (bias_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias_val + @cute.jit def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -109,6 +177,22 @@ def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au return tSrS_ssa + head_val + pos_val +@cute.jit +def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] + dtype = head_bias.element_type + + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_idx[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + 
pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_idx[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + # ============================================================================= # Score_mod functions that use global indices diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 11efcc8cdbc..b2a5ae21470 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -23,6 +23,16 @@ score_mod_batch_bias as score_mod_10, score_mod_dual_buffer as score_mod_11, ) # isort: split +from score_mod_definitions import ( + score_mod_identity_vectorized as score_mod_1_vectorized, + score_mod_causal_vectorized as score_mod_2_vectorized, + score_mod_rel_bias as score_mod_3_vectorized, + score_mod_rel_bias_x2_vectorized as score_mod_4_vectorized, + score_mod_times_two_vectorized as score_mod_5_vectorized, + score_mod_alibi_vectorized as score_mod_6_vectorized, + score_mod_batch_bias_vectorized as score_mod_10_vectorized, + score_mod_dual_buffer_vectorized as score_mod_11_vectorized, +) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, @@ -59,6 +69,21 @@ (score_mod_11, dual_buffer_bias), ] +# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) +TEST_PAIRS_VECTORIZED = [ + (score_mod_1, score_mod_1_vectorized), + (score_mod_2, score_mod_2_vectorized), + (score_mod_3, score_mod_3_vectorized), + (score_mod_4, score_mod_4_vectorized), + (score_mod_5, score_mod_5_vectorized), + (score_mod_6, score_mod_6_vectorized), +] + +TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED = [ + (score_mod_10, score_mod_10_vectorized), + (score_mod_11, score_mod_11_vectorized), +] + SEQLEN_CONFIGS = [ (1, 1), (64, 128), @@ -92,12 +117,8 @@ def create_tensors( return q, k, v -def run_cute_flash( - q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False -) -> torch.Tensor: - q_transposed, 
k_transposed, v_transposed = map( - lambda x: x.transpose(1, 2), (q, k, v) - ) +def run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False) -> torch.Tensor: + q_transposed, k_transposed, v_transposed = map(lambda x: x.transpose(1, 2), (q, k, v)) out = torch.empty_like(q_transposed) _flash_attn_fwd( q_transposed, @@ -116,9 +137,7 @@ def run_cute_flash( def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: if dtype is not None: q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) - return flex_attention( - q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] - ) + return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @@ -174,6 +193,40 @@ def test_cute_vs_flex_attention( ) +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_VECTORIZED) +def test_cute_score_mod_vectorized( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_pair, +): + """Tests equality between original and vectorized versions of score mods""" + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + out_ref = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) + + for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa) + + assert torch.equal(out, 
out_ref) + + @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -214,9 +267,7 @@ def test_cute_vs_flex_attention_with_aux_tensors( out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash( - q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa - ) + out_cute = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -247,19 +298,61 @@ def test_cute_vs_flex_attention_with_aux_tensors( ) -def _generate_block_kvcache( - seqlen_k, page_size, batch_size, nheads_k, d, device, dtype +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED) +def test_cute_score_mod_with_aux_tensors_vectorized( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_pair, ): + """Tests equality between original and vectorized versions of score mods""" + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair + batch_size = 2 + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [buffer] + assert buffer.shape == (batch_size,) + elif cute_score_mod 
== score_mod_11: + head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + assert head_bias.shape == (num_q_heads,) + assert pos_scale.shape == (seqlen_q,) + + out_ref = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) + + for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash( + q, k, v, cute_vectorized_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + ) + + assert torch.equal(out, out_ref) + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): import math from einops import rearrange num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) - v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + k_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) + v_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", @@ -321,12 +414,8 @@ def test_score_mod_with_paged_kvcache( q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: - k_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) - v_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) + k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) + v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None @@ -342,9 +431,7 @@ def 
test_score_mod_with_paged_kvcache( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) - cache_seqlens = torch.randint( - 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device - ) + cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) from einops import rearrange @@ -426,9 +513,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - print( - f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" - ) + print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") @@ -478,12 +563,8 @@ def test_score_mod_with_paged_kvcache_aux_tensors( q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: - k_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) - v_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) + k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) + v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None @@ -499,9 +580,7 @@ def test_score_mod_with_paged_kvcache_aux_tensors( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) - cache_seqlens = torch.randint( - 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device - ) + cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 @@ -595,9 +674,7 @@ def 
masked_score_mod(score, b, h, q_idx, kv_idx): pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - print( - f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" - ) + print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") @@ -628,7 +705,7 @@ def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info @cute.jit def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): """Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0). - + At unmasked positions (q_idx >= kv_idx), grad passes through. At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0. """ @@ -678,7 +755,9 @@ def run_cute_flash_bwd( v_t = v.transpose(1, 2) out, lse = _flash_attn_fwd( - q_t, k_t, v_t, + q_t, + k_t, + v_t, return_lse=True, score_mod=cute_score_mod, aux_tensors=aux_tensors, @@ -688,8 +767,12 @@ def run_cute_flash_bwd( grad_out = torch.randn_like(out) dq, dk, dv = _flash_attn_bwd( - q_t, k_t, v_t, - out, grad_out, lse, + q_t, + k_t, + v_t, + out, + grad_out, + lse, score_mod=cute_score_mod, score_mod_bwd=cute_score_mod_bwd, aux_tensors=aux_tensors, @@ -718,9 +801,7 @@ def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): v = v.requires_grad_(True) compiled_flex = torch.compile(flex_attention) - out = compiled_flex( - q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] - ) + out = compiled_flex(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out) return out, dq, dk, dv @@ -755,15 +836,11 @@ def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, 
score_ seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype ) - out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd( - q, k, v, cute_fwd, cute_bwd - ) + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd) out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any(), "dQ contains NaN" assert not torch.isnan(dk_cute).any(), "dK contains NaN" @@ -839,9 +916,7 @@ def test_cute_vs_flex_attention_backward_with_aux( out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any() assert not torch.isnan(dk_cute).any() @@ -901,9 +976,7 @@ def test_cute_vs_flex_attention_backward_pack_gqa( out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any() assert not torch.isnan(dk_cute).any() From cf14ef123f010567c5368abd1ac3b3eedd59e41d Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 28 Jan 2026 17:11:00 +0000 Subject: [PATCH 02/10] remove commented out change --- flash_attn/cute/flash_fwd_sm100.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 7acc0daed91..82d091f199f 100644 --- 
a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -134,12 +134,6 @@ def __init__( self.vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) - # if hasattr(score_mod, "__vec_size__"): - # self.vec_size: cutlass.Constexpr = score_mod.__vec_size__ - # elif cutlass.const_expr(has_aux_tensors): - # self.vec_size: cutlass.Constexpr = 1 - # else: - # self.vec_size: cutlass.Constexpr = 2 # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False From 85566ec2b9c9766596f04de15588dfcdfbf96268 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Thu, 29 Jan 2026 00:34:16 +0000 Subject: [PATCH 03/10] fix typo --- tests/cute/score_mod_definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cute/score_mod_definitions.py b/tests/cute/score_mod_definitions.py index f131bf2ae09..aaa3664abf0 100644 --- a/tests/cute/score_mod_definitions.py +++ b/tests/cute/score_mod_definitions.py @@ -50,7 +50,7 @@ def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_ kv_idx0 = kv_idx[0] diff0 = q_idx0 - kv_idx0 abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) - for i in cutlass.rante_constexpr(cute.size(kv_idx.shape)): + for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): diffi = diff0 - i abs_diff[i] = mlir_math.absi(diffi) return tSrS_ssa + abs_diff.load().to(cutlass.Float32) From 6d9ef84c0ea190768fd7235eb4e81f18a81eb2e8 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Fri, 30 Jan 2026 22:22:58 +0000 Subject: [PATCH 04/10] add aux tensor alignment to compile key --- flash_attn/cute/cute_dsl_utils.py | 11 +++++++++++ flash_attn/cute/interface.py | 11 ++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 
86c05dd3262..ec750e8179b 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -170,6 +170,17 @@ def to_cute_aux_tensor(t, enable_tvm_ffi=True): ) +def get_aux_tensor_metadata(aux_tensors): + return tuple( + ( + getattr(t, "__assumed_align__", 0), + getattr(t, "__leading_dim__", -1), + hasattr(t, "__leading_dim__"), + ) + for t in aux_tensors + ) + + def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: """Return tuple of bools indicating which dims have stride=0 (broadcast). diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index cfeb5b48cd8..ef2df23e448 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -41,7 +41,7 @@ from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import to_cute_tensor, to_cute_aux_tensor +from flash_attn.cute.cute_dsl_utils import to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess @@ -368,7 +368,11 @@ def _flash_attn_fwd( block_size=(m_block_size, n_block_size), q_stage=q_stage, ) - + if aux_tensors is not None: + aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) + else: + aux_tensor_metadata = None + compile_key = ( dtype, head_dim, @@ -379,7 +383,7 @@ def _flash_attn_fwd( mask_mod_hash, use_block_sparsity, block_sparse_broadcast_pattern, - len(aux_tensors) if aux_tensors is not None else 0, + aux_tensor_metadata, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, @@ -432,6 +436,7 @@ def _flash_attn_fwd( sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) cute_aux_tensors = None + aux_tensor_metadata = None if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] From 01d228a8a3b38ebd05da9d1912bc88e54f9e4e60 Mon Sep 17 
00:00:00 2001 From: reubenconducts Date: Wed, 4 Feb 2026 22:13:47 +0000 Subject: [PATCH 05/10] add varlen score mod vec tests --- tests/cute/test_score_mod_varlen.py | 165 ++++++++++++++++++++++------ 1 file changed, 134 insertions(+), 31 deletions(-) diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py index 7cca7f2aa0a..22dddad393f 100644 --- a/tests/cute/test_score_mod_varlen.py +++ b/tests/cute/test_score_mod_varlen.py @@ -28,6 +28,16 @@ score_mod_stress_xor_pattern, score_mod_times_two, ) # isort: split +from score_mod_definitions import ( + score_mod_identity_vectorized, + score_mod_causal_vectorized, + score_mod_rel_bias as score_mod_rel_bias_vectorized, + score_mod_rel_bias_x2_vectorized, + score_mod_times_two_vectorized, + score_mod_alibi_vectorized, + score_mod_batch_bias_vectorized, + score_mod_dual_buffer_vectorized, +) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, @@ -77,6 +87,17 @@ (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"), ] +# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) +TEST_PAIRS_VECTORIZED_NO_GLOBAL = [ + (score_mod_identity, score_mod_identity_vectorized, None), + (score_mod_causal, score_mod_causal_vectorized, None), + (score_mod_rel_bias, score_mod_rel_bias_vectorized, None), + (score_mod_rel_bias_x2, score_mod_rel_bias_x2_vectorized, None), + (score_mod_times_two, score_mod_times_two_vectorized, None), + (score_mod_alibi, score_mod_alibi_vectorized, None), + (score_mod_batch_bias, score_mod_batch_bias_vectorized, "batch"), + (score_mod_dual_buffer, score_mod_dual_buffer_vectorized, "dual_buffer"), +] # (cute_score_mod, eager_factory, aux_type, requires_global) # aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer" # requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both) @@ -116,39 +137,39 @@ ] SEQLEN_CONFIGS = [ - ([1], [1]), - ([1, 1], 
[1, 1]), - ([2, 3], [2, 3]), - ([8, 16], [8, 16]), - ([32, 32], [32, 32]), - ([64, 128], [64, 128]), - ([64, 56, 128], [64, 56, 128]), - ([256, 512], [256, 512]), - ([113, 203], [113, 203]), - ([239, 1], [239, 1]), - ([64], [64]), - ([128, 128], [128, 128]), - ([32, 32, 32, 32], [32, 32, 32, 32]), - ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]), - ([1, 1024], [1, 1024]), - ([1024, 1], [1024, 1]), - ([1, 256, 1], [1, 256, 1]), - ([256, 1, 256], [256, 1, 256]), - ([17, 33, 65], [17, 33, 65]), - ([64, 128], [32, 64]), - ([100, 100], [50, 50]), + # ([1], [1]), + # ([1, 1], [1, 1]), + # ([2, 3], [2, 3]), + # ([8, 16], [8, 16]), + # ([32, 32], [32, 32]), + # ([64, 128], [64, 128]), + # ([64, 56, 128], [64, 56, 128]), + # ([256, 512], [256, 512]), + # ([113, 203], [113, 203]), + # ([239, 1], [239, 1]), + # ([64], [64]), + # ([128, 128], [128, 128]), + # ([32, 32, 32, 32], [32, 32, 32, 32]), + # ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]), + # ([1, 1024], [1, 1024]), + # ([1024, 1], [1024, 1]), + # ([1, 256, 1], [1, 256, 1]), + # ([256, 1, 256], [256, 1, 256]), + # ([17, 33, 65], [17, 33, 65]), + # ([64, 128], [32, 64]), + # ([100, 100], [50, 50]), ([256, 512, 256], [128, 256, 128]), ([2, 1], [16384, 32 * 1024]), - ([1, 1], [128 * 1024] * 2), - ([2, 1], [8192, 8192]), - ([1, 3], [8192, 8192]), - ([3, 3], [8192, 8192]), - ([128, 128], [8192, 8192]), - ([2, 2, 2], [8 * 1024] * 3), - ([2, 1], [1024 * 32, 16384]), - ([1, 2], [1024 * 32, 16384]), - ([1, 1, 1], [128 * 1024] * 3), - ([1, 1, 1], [256 * 1024] * 3), + # ([1, 1], [128 * 1024] * 2), + # ([2, 1], [8192, 8192]), + # ([1, 3], [8192, 8192]), + # ([3, 3], [8192, 8192]), + # ([128, 128], [8192, 8192]), + # ([2, 2, 2], [8 * 1024] * 3), + # ([2, 1], [1024 * 32, 16384]), + # ([1, 2], [1024 * 32, 16384]), + # ([1, 1, 1], [128 * 1024] * 3), + # ([1, 1, 1], [256 * 1024] * 3), ] # ============================================================================= @@ -488,6 +509,88 @@ def test_varlen_with_score_mod( 
cu_seqlens_q=cu_seqlens_q if varlen_q else None, ) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_vec_tuple", TEST_PAIRS_VECTORIZED_NO_GLOBAL) +def test_varlen_with_score_mod_vectorized( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_tuple, +): + """Tests equality between original and vectorized versions of score mods""" + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod, aux_type = score_mod_vec_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + aux_tensors = None + if aux_type == "batch": + bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + out_ref = 
run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + # for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for vec_size in [1, 2, 4]: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash( + q, + k, + v, + cute_vectorized_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + assert torch.equal(out, out_ref) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("varlen_q", [True, False]) From 782f9bd00b61744d121ce2b1a7e094ed5563d7a1 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Wed, 4 Feb 2026 22:17:02 +0000 Subject: [PATCH 06/10] uncomment test configs --- tests/cute/test_score_mod_varlen.py | 65 ++++++++++++++--------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py index 22dddad393f..08a5695c050 100644 --- a/tests/cute/test_score_mod_varlen.py +++ b/tests/cute/test_score_mod_varlen.py @@ -137,39 +137,39 @@ ] SEQLEN_CONFIGS = [ - # ([1], [1]), - # ([1, 1], [1, 1]), - # ([2, 3], [2, 3]), - # ([8, 16], [8, 16]), - # ([32, 32], [32, 32]), - # ([64, 128], [64, 128]), - # ([64, 56, 128], [64, 56, 128]), - # ([256, 512], [256, 512]), - # ([113, 203], [113, 203]), - # ([239, 1], [239, 1]), - # ([64], [64]), - # ([128, 128], [128, 128]), - # ([32, 32, 32, 32], [32, 32, 32, 32]), - # ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]), - # ([1, 1024], [1, 1024]), - # ([1024, 1], [1024, 1]), - # ([1, 256, 1], [1, 256, 1]), - # ([256, 1, 256], [256, 1, 256]), - # ([17, 33, 65], [17, 33, 65]), - # ([64, 128], [32, 64]), - # ([100, 100], [50, 50]), + ([1], [1]), + ([1, 1], [1, 1]), + ([2, 3], [2, 3]), + ([8, 16], [8, 16]), + ([32, 32], [32, 32]), + ([64, 128], [64, 128]), + ([64, 56, 128], [64, 56, 128]), + ([256, 512], [256, 512]), + ([113, 203], 
[113, 203]), + ([239, 1], [239, 1]), + ([64], [64]), + ([128, 128], [128, 128]), + ([32, 32, 32, 32], [32, 32, 32, 32]), + ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]), + ([1, 1024], [1, 1024]), + ([1024, 1], [1024, 1]), + ([1, 256, 1], [1, 256, 1]), + ([256, 1, 256], [256, 1, 256]), + ([17, 33, 65], [17, 33, 65]), + ([64, 128], [32, 64]), + ([100, 100], [50, 50]), ([256, 512, 256], [128, 256, 128]), ([2, 1], [16384, 32 * 1024]), - # ([1, 1], [128 * 1024] * 2), - # ([2, 1], [8192, 8192]), - # ([1, 3], [8192, 8192]), - # ([3, 3], [8192, 8192]), - # ([128, 128], [8192, 8192]), - # ([2, 2, 2], [8 * 1024] * 3), - # ([2, 1], [1024 * 32, 16384]), - # ([1, 2], [1024 * 32, 16384]), - # ([1, 1, 1], [128 * 1024] * 3), - # ([1, 1, 1], [256 * 1024] * 3), + ([1, 1], [128 * 1024] * 2), + ([2, 1], [8192, 8192]), + ([1, 3], [8192, 8192]), + ([3, 3], [8192, 8192]), + ([128, 128], [8192, 8192]), + ([2, 2, 2], [8 * 1024] * 3), + ([2, 1], [1024 * 32, 16384]), + ([1, 2], [1024 * 32, 16384]), + ([1, 1, 1], [128 * 1024] * 3), + ([1, 1, 1], [256 * 1024] * 3), ] # ============================================================================= @@ -576,8 +576,7 @@ def test_varlen_with_score_mod_vectorized( cu_seqlens_k=cu_seqlens_k, ) - # for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: - for vec_size in [1, 2, 4]: + for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash( q, From 11447a4bf96159325ca9b14404b8591f10d66463 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Thu, 5 Feb 2026 16:25:30 +0000 Subject: [PATCH 07/10] sm90 fwd --- flash_attn/cute/flash_fwd.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 9eaccda41bc..bba612bc4cb 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -113,10 +113,9 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if 
const_expr(has_aux_tensors): - self.vec_size: cutlass.Constexpr = 1 - else: - self.vec_size: cutlass.Constexpr = 2 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) @staticmethod def can_implement( From ab02dd3a771c08f230b934cf42d01533adfa8025 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 9 Feb 2026 22:01:54 +0000 Subject: [PATCH 08/10] update hash callable --- flash_attn/cute/utils.py | 74 ++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 251185edfcd..4d0ad0ccf62 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -17,36 +17,11 @@ import quack.activation +_MIXER_ATTRS = ("__vec_size__",) -def hash_callable(func: Callable, set_cute_hash=True) -> str: - """Hash a callable based on the source code or bytecode and closure values. - - Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` - attribute, that value is returned immediately. Code-generation backends such - as Inductor can set this attribute to avoid expensive runtime hashing. - - set_cute_hash: whether or not to set func.__cute_hash__ if not present - """ - if hasattr(func, "__cute_hash__"): - return func.__cute_hash__ - - # __vec_size__ is attr of @cute.jitted mod - if hasattr(func, "__vec_size__"): - vec_size = func.__vec_size__ - else: - vec_size = None - - # Unwrap decorated functions (e.g., cute.jit wrappers). 
- if hasattr(func, "__wrapped__"): - base_func = func.__wrapped__ - if hasattr(base_func, "__cute_hash__"): - return base_func.__cute_hash__ - func = base_func - - # if base func has __vec_size__, overwrite - if hasattr(func, "__vec_size__"): - vec_size = func.__vec_size__ +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" try: data = inspect.getsource(func).encode() except (OSError, TypeError): @@ -54,22 +29,41 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str: data = func.__code__.co_code else: data = repr(func).encode() - hasher = hashlib.sha256(data) - if hasattr(func, "__closure__") and func.__closure__ is not None: - for idx, cell in enumerate(func.__closure__): - cell_value = cell.cell_contents - hasher.update(repr(cell_value).encode()) - - hasher.update(str(vec_size).encode()) + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + return hasher.hexdigest() - hash = hasher.hexdigest() - if set_cute_hash: - func.__cute_hash__ = hash - - return hash +def hash_callable(func: Callable, mixer_attrs=_MIXER_ATTRS, set_cute_hash=True) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). 
+ base_func = getattr(func, "__wrapped__", func) + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + if all(v is None for v in mixer_values): + return base_hash + hasher = hashlib.sha256(base_hash.encode()) + for attr, val in zip(mixer_attrs, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) + return hasher.hexdigest() def create_softcap_scoremod(softcap_val): From 04fca59fceb2ff66925c4e49b189305a511ec766 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 9 Feb 2026 22:13:47 +0000 Subject: [PATCH 09/10] format hash callable --- flash_attn/cute/utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 4d0ad0ccf62..e7f843b9e6b 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -29,14 +29,19 @@ def _compute_base_hash(func: Callable) -> str: data = func.__code__.co_code else: data = repr(func).encode() + hasher = hashlib.sha256(data) + if hasattr(func, "__closure__") and func.__closure__ is not None: for cell in func.__closure__: hasher.update(repr(cell.cell_contents).encode()) + return hasher.hexdigest() -def hash_callable(func: Callable, mixer_attrs=_MIXER_ATTRS, set_cute_hash=True) -> str: +def hash_callable( + func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: """Hash a callable based on the source code or bytecode and closure values. Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` attribute, that value is returned immediately as the base hash, then @@ -49,20 +54,26 @@ def hash_callable(func: Callable, mixer_attrs=_MIXER_ATTRS, set_cute_hash=True) else: # Unwrap decorated functions (e.g., cute.jit wrappers). 
base_func = getattr(func, "__wrapped__", func) + if hasattr(base_func, "__cute_hash__"): base_hash = base_func.__cute_hash__ else: base_hash = _compute_base_hash(base_func) + if set_cute_hash: base_func.__cute_hash__ = base_hash # Mix in mutable metadata dunders mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + if all(v is None for v in mixer_values): return base_hash + hasher = hashlib.sha256(base_hash.encode()) + for attr, val in zip(mixer_attrs, mixer_values): hasher.update(f"{attr}={val!r}".encode()) + return hasher.hexdigest() From eaa4a49840be1f4ce4903943a825f3ce33c8a673 Mon Sep 17 00:00:00 2001 From: reubenconducts Date: Mon, 9 Feb 2026 22:39:03 +0000 Subject: [PATCH 10/10] shorten vec size tests --- tests/cute/test_score_mod.py | 6 ++++-- tests/cute/test_score_mod_varlen.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index b2a5ae21470..740d7ac7699 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -107,6 +107,8 @@ (4224, 4224), ] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 4] + def create_tensors( batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 @@ -220,7 +222,7 @@ def test_cute_score_mod_vectorized( out_ref = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) - for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa) @@ -337,7 +339,7 @@ def test_cute_score_mod_with_aux_tensors_vectorized( out_ref = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) - for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash( q, k, v, cute_vectorized_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa 
diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py index 08a5695c050..8b5749aa161 100644 --- a/tests/cute/test_score_mod_varlen.py +++ b/tests/cute/test_score_mod_varlen.py @@ -172,6 +172,8 @@ ([1, 1, 1], [256 * 1024] * 3), ] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 4] + # ============================================================================= # Helper functions # ============================================================================= @@ -576,7 +578,7 @@ def test_varlen_with_score_mod_vectorized( cu_seqlens_k=cu_seqlens_k, ) - for vec_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash( q,