diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 9d2f7aa739b..ec750e8179b 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -152,6 +152,35 @@ def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, ena return tensor.mark_layout_dynamic(leading_dim=leading_dim) +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + +def get_aux_tensor_metadata(aux_tensors): + return tuple( + ( + getattr(t, "__assumed_align__", 0), + getattr(t, "__leading_dim__", -1), + hasattr(t, "__leading_dim__"), + ) + for t in aux_tensors + ) + + def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: """Return tuple of bools indicating which dims have stride=0 (broadcast). diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 9eaccda41bc..bba612bc4cb 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -113,10 +113,9 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if const_expr(has_aux_tensors): - self.vec_size: cutlass.Constexpr = 1 - else: - self.vec_size: cutlass.Constexpr = 2 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) @staticmethod def can_implement( diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 886d02632a5..82d091f199f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -131,10 +131,9 @@ def __init__( ) self.score_mod = score_mod self.mask_mod = mask_mod - if cutlass.const_expr(has_aux_tensors): - self.vec_size: cutlass.Constexpr = 1 - else: - self.vec_size: cutlass.Constexpr = 2 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8d936602e31..ef2df23e448 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -41,7 +41,7 @@ from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import to_cute_tensor +from flash_attn.cute.cute_dsl_utils import to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess @@ -368,7 +368,11 @@ def _flash_attn_fwd( block_size=(m_block_size, n_block_size), q_stage=q_stage, ) - + if aux_tensors is not None: + aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) + else: + aux_tensor_metadata = None + compile_key = ( dtype, head_dim, @@ -379,7 +383,7 @@ def _flash_attn_fwd( mask_mod_hash, use_block_sparsity, block_sparse_broadcast_pattern, - len(aux_tensors) if aux_tensors is not None else 0, + aux_tensor_metadata, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, @@ -432,8 +436,9 @@ def _flash_attn_fwd( sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) cute_aux_tensors = None + aux_tensor_metadata = None if aux_tensors is not None: - cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] + cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index f2383e89415..e7f843b9e6b 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -17,26 +17,11 @@ import quack.activation +_MIXER_ATTRS = ("__vec_size__",) -def hash_callable(func: Callable, set_cute_hash=True) -> str: - """Hash a callable based on the source code or bytecode and closure values. - - Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` - attribute, that value is returned immediately. Code-generation backends such - as Inductor can set this attribute to avoid expensive runtime hashing. - - set_cute_hash: whether or not to set func.__cute_hash__ if not present - """ - if hasattr(func, "__cute_hash__"): - return func.__cute_hash__ - - # Unwrap decorated functions (e.g., cute.jit wrappers). - if hasattr(func, "__wrapped__"): - base_func = func.__wrapped__ - if hasattr(base_func, "__cute_hash__"): - return base_func.__cute_hash__ - func = base_func +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" try: data = inspect.getsource(func).encode() except (OSError, TypeError): @@ -48,16 +33,48 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str: hasher = hashlib.sha256(data) if hasattr(func, "__closure__") and func.__closure__ is not None: - for idx, cell in enumerate(func.__closure__): - cell_value = cell.cell_contents - hasher.update(repr(cell_value).encode()) + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + + return hasher.hexdigest() + + +def hash_callable( + func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). + base_func = getattr(func, "__wrapped__", func) + + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + + if all(v is None for v in mixer_values): + return base_hash - hash = hasher.hexdigest() + hasher = hashlib.sha256(base_hash.encode()) - if set_cute_hash: - func.__cute_hash__ = hash + for attr, val in zip(_MIXER_ATTRS, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) - return hash + return hasher.hexdigest() def create_softcap_scoremod(softcap_val): diff --git a/tests/cute/score_mod_definitions.py b/tests/cute/score_mod_definitions.py index be6333a6448..aaa3664abf0 100644 --- a/tests/cute/score_mod_definitions.py +++ b/tests/cute/score_mod_definitions.py @@ -15,12 +15,28 @@ def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t return tSrS_ssa +@cute.jit +def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa + + @cute.jit def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): mask = operator.ge(q_idx, kv_idx) return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) +@cute.jit +def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask = cute.make_rmem_tensor(kv_idx.shape, dtype=cutlass.Boolean) + kv_idx0 = kv_idx[0] + q_idx0 = q_idx[0] + for i in cutlass.range_constexpr(cute.size(mask.shape)): + mask[i] = q_idx0 >= kv_idx0 + i + mask_ssa = mask.load() + return cute.where(mask_ssa, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + @cute.jit def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): diff = q_idx - kv_idx @@ -28,6 +44,18 @@ def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t return tSrS_ssa + abs_diff.to(cutlass.Float32) +@cute.jit +def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + q_idx0 = q_idx[0] + kv_idx0 = kv_idx[0] + diff0 = q_idx0 - kv_idx0 + abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) + for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): + diffi = diff0 - i + abs_diff[i] = mlir_math.absi(diffi) + return tSrS_ssa + abs_diff.load().to(cutlass.Float32) + + @cute.jit def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): diff = q_idx - kv_idx @@ -36,10 +64,25 @@ def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au return tSrS_ssa + scaled.to(cutlass.Float32) +@cute.jit +def score_mod_rel_bias_x2_vectorized( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + q_idx0 = q_idx[0] + kv_idx0 = kv_idx[0] + diff0 = q_idx0 - kv_idx0 + abs_diff_x2 = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) + for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): + diffi = diff0 - i + abs_diff_x2[i] = mlir_math.absi(diffi) * 2 + return tSrS_ssa + abs_diff_x2.load().to(cutlass.Float32) + + @cute.jit def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): return tSrS_ssa * cute.full_like(tSrS_ssa, 2) +score_mod_times_two_vectorized = score_mod_times_two @cute.jit def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -53,6 +96,21 @@ def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tens abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32) return score - slope * abs_diff +@cute.jit +def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + score = tSrS_ssa.to(cutlass.Float32) + slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8) + slope = cute.math.exp2( + slope_exp.to(cutlass.Float32) + * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) + ) + diff0 = q_idx[0] - kv_idx[0] + abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype) + for i in cutlass.range_constexpr(cute.size(abs_diff.shape)): + diffi = diff0 - i + abs_diff[i] = mlir_math.absi(diffi) + return score - slope * abs_diff.load().to(cutlass.Float32) + @cute.jit def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -88,6 +146,16 @@ def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux bias_val = (bias_frag.load()).to(cutlass.Float32) return tSrS_ssa + bias_val +@cute.jit +def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + batch_bias = aux_tensors[0] + dtype = batch_bias.element_type + b_idx0 = b_idx[0] + bias_frag = cute.make_rmem_tensor(1, dtype) + bias_frag[0] = batch_bias[b_idx0] + bias_val = (bias_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias_val + @cute.jit def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -109,6 +177,22 @@ def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au return tSrS_ssa + head_val + pos_val +@cute.jit +def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] + dtype = head_bias.element_type + + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_idx[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_idx[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + # ============================================================================= # Score_mod functions that use global indices diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 11efcc8cdbc..740d7ac7699 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -23,6 +23,16 @@ score_mod_batch_bias as score_mod_10, score_mod_dual_buffer as score_mod_11, ) # isort: split +from score_mod_definitions import ( + score_mod_identity_vectorized as score_mod_1_vectorized, + score_mod_causal_vectorized as score_mod_2_vectorized, + score_mod_rel_bias as score_mod_3_vectorized, + score_mod_rel_bias_x2_vectorized as score_mod_4_vectorized, + score_mod_times_two_vectorized as score_mod_5_vectorized, + score_mod_alibi_vectorized as score_mod_6_vectorized, + score_mod_batch_bias_vectorized as score_mod_10_vectorized, + score_mod_dual_buffer_vectorized as score_mod_11_vectorized, +) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, @@ -59,6 +69,21 @@ (score_mod_11, dual_buffer_bias), ] +# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) +TEST_PAIRS_VECTORIZED = [ + (score_mod_1, score_mod_1_vectorized), + (score_mod_2, score_mod_2_vectorized), + (score_mod_3, score_mod_3_vectorized), + (score_mod_4, score_mod_4_vectorized), + (score_mod_5, score_mod_5_vectorized), + (score_mod_6, score_mod_6_vectorized), +] + +TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED = [ + (score_mod_10, score_mod_10_vectorized), + (score_mod_11, score_mod_11_vectorized), +] + SEQLEN_CONFIGS = [ (1, 1), (64, 128), @@ -82,6 +107,8 @@ (4224, 4224), ] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 4] + def create_tensors( batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 @@ -92,12 +119,8 @@ def create_tensors( return q, k, v -def run_cute_flash( - q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False -) -> torch.Tensor: - q_transposed, k_transposed, v_transposed = map( - lambda x: x.transpose(1, 2), (q, k, v) - ) +def run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False) -> torch.Tensor: + q_transposed, k_transposed, v_transposed = map(lambda x: x.transpose(1, 2), (q, k, v)) out = torch.empty_like(q_transposed) _flash_attn_fwd( q_transposed, @@ -116,9 +139,7 @@ def run_cute_flash( def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: if dtype is not None: q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) - return flex_attention( - q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] - ) + return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @@ -174,6 +195,40 @@ def test_cute_vs_flex_attention( ) +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_VECTORIZED) +def test_cute_score_mod_vectorized( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_pair, +): + """Tests equality between original and vectorized versions of score mods""" + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + out_ref = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) + + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa) + + assert torch.equal(out, out_ref) + + @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -214,9 +269,7 @@ def test_cute_vs_flex_attention_with_aux_tensors( out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash( - q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa - ) + out_cute = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -247,19 +300,61 @@ def test_cute_vs_flex_attention_with_aux_tensors( ) -def _generate_block_kvcache( - seqlen_k, page_size, batch_size, nheads_k, d, device, dtype +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED) +def test_cute_score_mod_with_aux_tensors_vectorized( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_pair, ): + """Tests equality between original and vectorized versions of score mods""" + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair + batch_size = 2 + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [buffer] + assert buffer.shape == (batch_size,) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + assert head_bias.shape == (num_q_heads,) + assert pos_scale.shape == (seqlen_q,) + + out_ref = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) + + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash( + q, k, v, cute_vectorized_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + ) + + assert torch.equal(out, out_ref) + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): import math from einops import rearrange num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) - v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + k_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) + v_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", @@ -321,12 +416,8 @@ def test_score_mod_with_paged_kvcache( q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: - k_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) - v_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) + k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) + v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None @@ -342,9 +433,7 @@ def test_score_mod_with_paged_kvcache( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) - cache_seqlens = torch.randint( - 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device - ) + cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) from einops import rearrange @@ -426,9 +515,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - print( - f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" - ) + print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") @@ -478,12 +565,8 @@ def test_score_mod_with_paged_kvcache_aux_tensors( q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: - k_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) - v_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) + k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) + v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None @@ -499,9 +582,7 @@ def test_score_mod_with_paged_kvcache_aux_tensors( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) - cache_seqlens = torch.randint( - 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device - ) + cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 @@ -595,9 +676,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - print( - f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" - ) + print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") @@ -628,7 +707,7 @@ def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info @cute.jit def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): """Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0). - + At unmasked positions (q_idx >= kv_idx), grad passes through. At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0. """ @@ -678,7 +757,9 @@ def run_cute_flash_bwd( v_t = v.transpose(1, 2) out, lse = _flash_attn_fwd( - q_t, k_t, v_t, + q_t, + k_t, + v_t, return_lse=True, score_mod=cute_score_mod, aux_tensors=aux_tensors, @@ -688,8 +769,12 @@ def run_cute_flash_bwd( grad_out = torch.randn_like(out) dq, dk, dv = _flash_attn_bwd( - q_t, k_t, v_t, - out, grad_out, lse, + q_t, + k_t, + v_t, + out, + grad_out, + lse, score_mod=cute_score_mod, score_mod_bwd=cute_score_mod_bwd, aux_tensors=aux_tensors, @@ -718,9 +803,7 @@ def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): v = v.requires_grad_(True) compiled_flex = torch.compile(flex_attention) - out = compiled_flex( - q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] - ) + out = compiled_flex(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out) return out, dq, dk, dv @@ -755,15 +838,11 @@ def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_ seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype ) - out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd( - q, k, v, cute_fwd, cute_bwd - ) + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd) out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any(), "dQ contains NaN" assert not torch.isnan(dk_cute).any(), "dK contains NaN" @@ -839,9 +918,7 @@ def test_cute_vs_flex_attention_backward_with_aux( out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any() assert not torch.isnan(dk_cute).any() @@ -901,9 +978,7 @@ def test_cute_vs_flex_attention_backward_pack_gqa( out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any() assert not torch.isnan(dk_cute).any() diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py index 7cca7f2aa0a..8b5749aa161 100644 --- a/tests/cute/test_score_mod_varlen.py +++ b/tests/cute/test_score_mod_varlen.py @@ -28,6 +28,16 @@ score_mod_stress_xor_pattern, score_mod_times_two, ) # isort: split +from score_mod_definitions import ( + score_mod_identity_vectorized, + score_mod_causal_vectorized, + score_mod_rel_bias as score_mod_rel_bias_vectorized, + score_mod_rel_bias_x2_vectorized, + score_mod_times_two_vectorized, + score_mod_alibi_vectorized, + score_mod_batch_bias_vectorized, + score_mod_dual_buffer_vectorized, +) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, @@ -77,6 +87,17 @@ (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"), ] +# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) +TEST_PAIRS_VECTORIZED_NO_GLOBAL = [ + (score_mod_identity, score_mod_identity_vectorized, None), + (score_mod_causal, score_mod_causal_vectorized, None), + (score_mod_rel_bias, score_mod_rel_bias_vectorized, None), + (score_mod_rel_bias_x2, score_mod_rel_bias_x2_vectorized, None), + (score_mod_times_two, score_mod_times_two_vectorized, None), + (score_mod_alibi, score_mod_alibi_vectorized, None), + (score_mod_batch_bias, score_mod_batch_bias_vectorized, "batch"), + (score_mod_dual_buffer, score_mod_dual_buffer_vectorized, "dual_buffer"), +] # (cute_score_mod, eager_factory, aux_type, requires_global) # aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer" # requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both) @@ -151,6 +172,8 @@ ([1, 1, 1], [256 * 1024] * 3), ] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 4] + # ============================================================================= # Helper functions # ============================================================================= @@ -488,6 +511,87 @@ def test_varlen_with_score_mod( cu_seqlens_q=cu_seqlens_q if varlen_q else None, ) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_vec_tuple", TEST_PAIRS_VECTORIZED_NO_GLOBAL) +def test_varlen_with_score_mod_vectorized( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_tuple, +): + """Tests equality between original and vectorized versions of score mods""" + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod, aux_type = score_mod_vec_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + aux_tensors = None + if aux_type == "batch": + bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + out_ref = run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash( + q, + k, + v, + cute_vectorized_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + assert torch.equal(out, out_ref) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("varlen_q", [True, False])