diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py
index ed4b48f04c..5c9db14959 100644
--- a/aiter/ops/mha.py
+++ b/aiter/ops/mha.py
@@ -965,6 +965,7 @@ def cmdGenFunc_mha_batch_prefill(
     is_causal: bool,
     window_size_left: int,
     window_size_right: int,
+    sink_size: int,
     return_softmax_lse: bool,
     return_dropout_randval: bool,
     out: Optional[Tensor] = None,
@@ -1046,6 +1047,17 @@ def cmdGenFunc_mha_batch_prefill(
         # PERTENSOR: per-tensor quantization
         md_name += "_pertensor"
         filter_fwd += "_pertensor*"
+    # Sink only applies when there is a causal/window mask; full attention
+    # (window_size_left==-1 and window_size_right==-1) ignores sink_size.
+    has_effective_sink = sink_size > 0 and (
+        causal or not (window_size_left == -1 and window_size_right == -1)
+    )
+    if has_effective_sink:
+        md_name += "_sink"
+        filter_fwd += "_sink*"
+    else:
+        md_name += "_nsink"
+        filter_fwd += "_nsink*"
     blob_gen_cmd = [
         f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill "
         "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd)
@@ -2739,6 +2751,7 @@ def mha_batch_prefill_fake_tensors(
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
+   sink_size: int,
    return_softmax_lse: bool,
    return_dropout_randval: bool,
    out: Optional[torch.Tensor] = None,
@@ -2823,6 +2836,7 @@ def mha_batch_prefill(
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
+   sink_size: int,
    return_softmax_lse: bool,
    return_dropout_randval: bool,
    out: Optional[Tensor] = None,
@@ -2857,6 +2871,7 @@ def _mha_batch_prefill(
    logits_soft_cap: float = 0.0,
    window_size_left: int = -1,
    window_size_right: int = -1,
+   sink_size: int = 0,
    bias: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_lse: bool = False,
@@ -2892,6 +2907,7 @@ def _mha_batch_prefill(
        causal,
        window_size_left,
        window_size_right,
+       sink_size,
        return_lse,
        return_softmax,
        out,
@@ -2906,7 +2922,6 @@ def _mha_batch_prefill(
        seqlen_k,
        sink_ptr,
        None,
-       # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
    )
    return out, softmax_lse, S_dmask, rng_state
 
@@ -2938,6 +2953,7 @@ def mha_batch_prefill_func(
    v_descale=None,
    kv_block_descale=None,  # [num_block, num_kv_head, 2] per-page K/V descales
    sink_ptr=None,
+   sink_size: int = 0,
 ):
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
@@ -2990,6 +3006,7 @@ def mha_batch_prefill_func(
        logits_soft_cap=logits_soft_cap,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
+       sink_size=sink_size,
        alibi_slopes=alibi_slopes,
        return_lse=return_lse,
        return_softmax=return_attn_probs and dropout_p > 0,
diff --git a/csrc/cpp_itfs/mha_fwd_batch_prefill.cu b/csrc/cpp_itfs/mha_fwd_batch_prefill.cu
index 7994e7b2d9..daf298d083 100644
--- a/csrc/cpp_itfs/mha_fwd_batch_prefill.cu
+++ b/csrc/cpp_itfs/mha_fwd_batch_prefill.cu
@@ -16,7 +16,8 @@ get_mha_batch_prefill_traits(int head_size_q,
                              ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout,
                              ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table,
                              int page_size,
-                             bool skip_min_seqlen_q = false)
+                             bool skip_min_seqlen_q = false,
+                             bool has_sink = false)
 {
     return mha_batch_prefill_traits(head_size_q,
                                     head_size_v,
@@ -29,6 +30,7 @@ get_mha_batch_prefill_traits(int head_size_q,
                                     has_dropout,
                                     qscale_type,
                                     skip_min_seqlen_q,
+                                    has_sink,
                                     kv_memory_layout,
                                     kv_lookup_table,
                                     page_size);
@@ -47,13 +49,14 @@ float mha_batch_prefill(mha_batch_prefill_args args,
     int head_size_q = args.hdim_q;
     int head_size_v = args.hdim_v;
     bool has_dropout = args.p_drop > 0.f;
+    bool has_sink = args.sink_size > 0;
 
     // The kUseGlobalLoad decision (>2GB KV cache → use `global_load_lds_*`
     // instead of SRD `buffer_load_*`) is made per-arm inside the auto-generated
     // dispatcher in fmha_batch_prefill_api.cpp, where each arm knows its own
     // compile-time bn0 and dtype element size. The wrapper just forwards args;
     // no runtime trait field for it.
-    auto traits = get_mha_batch_prefill_traits(head_size_q,
+    auto traits = get_mha_batch_prefill_traits(head_size_q,
                                                head_size_v,
                                                q_dtype_str,
                                                is_group_mode,
@@ -65,7 +68,9 @@ float mha_batch_prefill(mha_batch_prefill_args args,
                                                qscale_type,
                                                args.kv_memory_layout,
                                                args.kv_lookup_table,
-                                               args.page_block_size);
+                                               args.page_block_size,
+                                               /*skip_min_seqlen_q=*/false,
+                                               has_sink);
     return fmha_batch_prefill(traits, args, stream_config);
 }
 
diff --git a/csrc/include/mha_fwd.h b/csrc/include/mha_fwd.h
index 788fb2a008..d191ab652d 100644
--- a/csrc/include/mha_fwd.h
+++ b/csrc/include/mha_fwd.h
@@ -63,6 +63,7 @@ struct mha_batch_prefill_traits : public fmha_batch_prefill_traits
                             bool has_dropout,
                             quant_scale_enum qscale_type,
                             bool skip_min_seqlen_q,
+                            bool has_sink,
                             ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout,
                             ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table,
                             int page_size)
@@ -78,7 +79,7 @@ struct mha_batch_prefill_traits : public fmha_batch_prefill_traits
                             has_dropout,
                             qscale_type,
                             skip_min_seqlen_q,
-                            false, // has_sink
+                            has_sink,
                             kv_memory_layout,
                             kv_lookup_table,
                             page_size}
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index d1c4a0360b..add0c0ad2e 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -1097,6 +1097,7 @@ namespace py = pybind11;
           py::arg("is_causal"),                                 \
           py::arg("window_size_left"),                          \
           py::arg("window_size_right"),                         \
+          py::arg("sink_size"),                                 \
           py::arg("return_softmax_lse"),                        \
           py::arg("return_dropout_randval"),                    \
           py::arg("out") = std::nullopt,                        \
diff --git a/csrc/include/torch/mha_batch_prefill.h b/csrc/include/torch/mha_batch_prefill.h
index 8d7b510639..f911d6b085 100644
--- a/csrc/include/torch/mha_batch_prefill.h
+++ b/csrc/include/torch/mha_batch_prefill.h
@@ -21,6 +21,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
                   bool is_causal,
                   int window_size_left,
                   int window_size_right,
+                  int sink_size,
                   bool return_softmax_lse,
                   bool return_dropout_randval,
                   std::optional<at::Tensor> out_, // [total_q, hq, d]
diff --git a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu
index 15a4878ed9..bb7f1bb91b 100644
--- a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu
+++ b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu
@@ -304,7 +304,7 @@ get_ck_fmha_batch_prefill_args(bool has_lse,
        kv_last_page_lens_ptr = kv_last_page_lens.data_ptr();
    }
 
-   fmha_batch_prefill_args args;
+   fmha_batch_prefill_args args{}; // zero-initialize all fields
    args.q_ptr = q.data_ptr();
    args.k_ptr = k.data_ptr();
    args.v_ptr = v.data_ptr();
@@ -312,6 +312,9 @@ get_ck_fmha_batch_prefill_args(bool has_lse,
    args.q_descale_ptr = q_descale.has_value() ? q_descale.value().data_ptr() : nullptr;
    args.k_descale_ptr = k_descale.has_value() ? k_descale.value().data_ptr() : nullptr;
    args.v_descale_ptr = v_descale.has_value() ? v_descale.value().data_ptr() : nullptr;
+   // sink_ptr is independent of sink_size: when provided, the kernel always reads
+   // it as per-head logit values for the virtual sink token (sink_value = *ptr / scale_s).
+   // When null, sink_value defaults to -inf (virtual token excluded from softmax).
    args.sink_ptr = sink_ptr_.has_value() ? sink_ptr_.value().data_ptr() : nullptr;
    args.bias_ptr = bias_ptr;
    args.rand_val_ptr = has_dropout_randval ? dropout_randval.data_ptr() : nullptr;
@@ -363,6 +366,7 @@ get_ck_fmha_batch_prefill_args(bool has_lse,
    args.batch_stride_o = batch_stride_o;
    args.window_size_left = mask.left;
    args.window_size_right = mask.right;
+   args.sink_size = mask.sink;
    args.mask_type = static_cast<ck_tile::index_t>(mask.type);
    args.p_drop = p_dropout;
    args.s_randval = has_dropout_randval;
@@ -416,6 +420,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
                  bool is_causal,
                  int window_size_left,
                  int window_size_right,
+                 int sink_size,
                  bool return_softmax_lse,
                  bool return_dropout_randval,
                  std::optional<at::Tensor> out_, // [total_q, hq, d]
@@ -609,18 +614,18 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
    {
        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        window_size_right = 0;
-       std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
+       std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0" + "," + std::to_string(sink_size);
        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
    }
    else if(window_size_left == -1 && window_size_right == -1)
    {
-       mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
+       mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask; sink N/A for full attention
    }
    else
    {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify =
-           "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
+           "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right) + "," + std::to_string(sink_size);
        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
    }
 
diff --git a/op_tests/test_batch_prefill.py b/op_tests/test_batch_prefill.py
index ee5489fec4..482ca4a42e 100644
--- a/op_tests/test_batch_prefill.py
+++ b/op_tests/test_batch_prefill.py
@@ -2557,3 +2557,305 @@ def run_batch_prefill_kv_blockscale(
     total = len(collected)
     print(f"\nTotal: {total}, Passed: {passed}, Skipped: {skipped}")
     print("=" * 100)
+
+
+# =============================================================================
+# StreamLLM Sink Token Tests
+# =============================================================================
+
+
+def ref_masked_attention_with_sink(
+    query,
+    key,
+    value,
+    window_left,
+    sink_size,
+    sink_ptr_value,
+):
+    """
+    Reference attention with StreamLLM sink semantics.
+
+    Args:
+        query: [seqlen_q, num_heads, head_dim]
+        key: [seqlen_k, num_heads, head_dim]
+        value: [seqlen_k, num_heads, head_dim]
+        window_left: left window size (-1 = infinite)
+        sink_size: number of KV tokens at start always attended
+        sink_ptr_value: per-head float tensor [num_heads] or None.
+            When not None, a virtual sink token with this scaled
+            logit is appended to the attention matrix (it steals
+            probability mass but has no V contribution).
+
+    Valid KV range for query at absolute position abs_q = seqlen_k - seqlen_q + i_q:
+        k < sink_size (sink region, always valid)
+        OR
+        abs_q - window_left <= k <= abs_q (window region, window_left=-1 means k >= 0)
+    """
+    head_dim = query.shape[2]
+    seqlen_q = query.shape[0]
+    seqlen_k = key.shape[0]
+    num_heads = query.shape[1]
+    scale = 1.0 / math.sqrt(head_dim)
+
+    # [num_heads, seqlen_q, seqlen_k]
+    attn = scale * torch.einsum("qhd,khd->hqk", query.float(), key.float())
+
+    # Build mask vectorized to avoid per-element GPU synchronization
+    # i_q: [seqlen_q, 1], i_k: [1, seqlen_k]
+    i_q = torch.arange(seqlen_q, device=query.device).unsqueeze(1)  # [sq, 1]
+    i_k = torch.arange(seqlen_k, device=query.device).unsqueeze(0)  # [1, sk]
+    abs_q = seqlen_k - seqlen_q + i_q  # [sq, 1]
+    k_end = abs_q  # causal boundary
+    if window_left < 0:
+        k_start_window = torch.zeros_like(abs_q)
+    else:
+        k_start_window = torch.clamp(abs_q - window_left, min=sink_size)
+    is_sink = i_k < sink_size  # [1, sk]
+    is_window = (i_k >= k_start_window) & (i_k <= k_end)  # [sq, sk]
+    valid = is_sink | is_window  # [sq, sk]
+    # attn: [H, sq, sk] — broadcast mask over heads
+    attn.masked_fill_(~valid.unsqueeze(0), float("-inf"))
+
+    if sink_ptr_value is not None:
+        # Append virtual sink token column: logit = sink_ptr_value[h] (scaled space)
+        # Shape: [num_heads, seqlen_q, 1]
+        virt = sink_ptr_value.float().view(num_heads, 1, 1).expand(-1, seqlen_q, 1)
+        attn_ext = torch.cat([attn, virt], dim=-1)  # [H, sq, sk+1]
+        P_ext = torch.softmax(attn_ext, dim=-1)
+        P = P_ext[:, :, :seqlen_k]  # drop virtual column (V contribution = 0)
+    else:
+        P = torch.softmax(attn, dim=-1)
+
+    out = torch.einsum("hqk,khd->qhd", P, value.float())
+    return out.to(query.dtype)
+
+
+def run_batch_prefill_sink(
+    batch_size,
+    qo_len,
+    kv_len,
+    page_size,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim,
+    window_left,
+    sink_size,
+    sink_ptr_value,
+    dtype,
+    seed,
+):
+    """
+    Run batch_prefill with sink tokens and compare against torch reference.
+
+    sink_ptr_value: float or None. When float, a sink_ptr tensor of shape
+    [num_qo_heads] filled with this value is passed to the kernel.
+    """
+    if seed is not None:
+        torch.manual_seed(seed)
+
+    k_vector_size = get_vector_size(dtype)
+
+    # kv_len must be large enough to create a real gap between sink and window
+    if skip_test_if(
+        kv_len <= sink_size + window_left + 1,
+        f"kv_len={kv_len} too small for gap (need >{sink_size + window_left + 1})",
+    ):
+        return {"status": "skipped"}
+
+    qo_lens = build_qo_lens(batch_size, qo_len, randomize=batch_size > 1)
+    kv_lens = build_kv_lens(batch_size, kv_len, qo_lens, randomize=batch_size > 1)
+    max_qo_len = qo_lens.max().item()
+    max_kv_len = kv_lens.max().item()
+    q_indptr_cpu = convert_lens_to_indptr(qo_lens)
+
+    total_q = q_indptr_cpu[-1]
+    q = build_q_tensor(total_q, num_qo_heads, head_dim, dtype, -5, 5)
+
+    kv_cache = build_paged_kv_cache(
+        batch_size,
+        kv_len,
+        page_size,
+        num_kv_heads,
+        head_dim,
+        kv_lens,
+        -5,
+        5,
+        dtype,
+        contiguous_kv=True,
+    )
+    kv_data_fp32 = kv_cache["kv_data_fp32"]
+    kv_indices_cpu = kv_cache["kv_indices_cpu"]
+    kv_indptr_cpu_cache = kv_cache["kv_indptr_cpu"]
+    kv_last_page_len_cpu = kv_cache["kv_last_page_len_cpu"]
+
+    k_cache_ref, v_cache_ref = extract_kv_caches(kv_cache, contiguous_kv=True)
+    k_cache, v_cache = apply_kv_layout(
+        k_cache_ref,
+        v_cache_ref,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        k_vector_size,
+        "vectorized",
+    )
+
+    # Build sink_ptr tensor
+    sink_ptr = None
+    if sink_ptr_value is not None:
+        sink_ptr = torch.full(
+            (num_qo_heads,), sink_ptr_value, dtype=torch.float32, device="cuda"
+        )
+
+    # ── Torch reference ──────────────────────────────────────────────────────
+    # kv_data_fp32: [total_pages, 2, page_size, num_kv_heads, head_dim]
+    # dim 1: 0=K, 1=V
+    o_ref_list = []
+    for i in range(batch_size):
+        used_idx = kv_indices_cpu[kv_indptr_cpu_cache[i] : kv_indptr_cpu_cache[i + 1]]
+        last_len = kv_last_page_len_cpu[i].item()
+
+        # Full pages: [num_full_pages, page_size, num_kv_heads, head_dim]
+        # Last page: [:last_len, num_kv_heads, head_dim]
+        ki = torch.cat(
+            [
+                kv_data_fp32[used_idx[:-1], 0].reshape(-1, num_kv_heads, head_dim),
+                kv_data_fp32[used_idx[-1], 0, :last_len].reshape(
+                    -1, num_kv_heads, head_dim
+                ),
+            ],
+            dim=0,
+        ).to(dtype)
+        vi = torch.cat(
+            [
+                kv_data_fp32[used_idx[:-1], 1].reshape(-1, num_kv_heads, head_dim),
+                kv_data_fp32[used_idx[-1], 1, :last_len].reshape(
+                    -1, num_kv_heads, head_dim
+                ),
+            ],
+            dim=0,
+        ).to(dtype)
+
+        qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
+
+        if num_qo_heads != num_kv_heads:
+            ratio = num_qo_heads // num_kv_heads
+            ki = ki.repeat_interleave(ratio, dim=1)
+            vi = vi.repeat_interleave(ratio, dim=1)
+
+        o_ref_list.append(
+            ref_masked_attention_with_sink(qi, ki, vi, window_left, sink_size, sink_ptr)
+        )
+    o_ref = torch.cat(o_ref_list, dim=0)
+
+    # ── CK kernel ────────────────────────────────────────────────────────────
+    kv_indptr_gpu = kv_indptr_cpu_cache.to(0)
+    kv_indices_gpu = kv_indices_cpu.to(0)
+    kv_last_page_lens = kv_last_page_len_cpu.to(0)
+    cu_seqlens_q = q_indptr_cpu.to(0)
+
+    out = aiter.mha_batch_prefill_func(
+        q,
+        k_cache,
+        v_cache,
+        cu_seqlens_q,
+        kv_indptr_gpu,
+        kv_indices_gpu,
+        max_seqlen_q=max_qo_len,
+        max_seqlen_k=max_kv_len,
+        causal=True,
+        window_size=(window_left, -1),
+        sink_size=sink_size,
+        sink_ptr=sink_ptr,
+        kv_last_page_lens=kv_last_page_lens,
+        return_lse=False,
+    )
+
+    # ── Compare ──────────────────────────────────────────────────────────────
+    rtol, atol = get_tolerances(dtype)
+    assert_output_matches_reference(out, q_indptr_cpu, o_ref, rtol, atol)
+    return {"status": "passed"}
+
+
+@pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize(
+    "sink_ptr_value",
+    [None, 0.0, 2.0],
+    ids=["ptr=None", "ptr=0.0", "ptr=2.0"],
+)
+@pytest.mark.parametrize("sink_size", [4, 16])
+@pytest.mark.parametrize(
+    "window_left,kv_len",
+    [(128, 512), (1024, 2048)],
+    ids=["win=128/kv=512", "win=1024/kv=2048"],
+)
+@pytest.mark.parametrize("qo_len", [32, 128])
+@pytest.mark.parametrize("num_qo_heads,num_kv_heads", [(8, 1), (4, 4)])
+@pytest.mark.parametrize("head_dim", [128])
+@pytest.mark.parametrize("page_size", [16])
+@pytest.mark.parametrize("batch_size", [1, 2])
+def test_batch_prefill_sink(
+    batch_size,
+    page_size,
+    head_dim,
+    num_qo_heads,
+    num_kv_heads,
+    qo_len,
+    window_left,
+    kv_len,
+    sink_size,
+    sink_ptr_value,
+    dtype,
+    seed,
+):
+    """
+    Test batch_prefill with StreamLLM sink token support.
+
+    Validates:
+    - sink_size: first sink_size KV positions always attended (never window-masked)
+    - sink_ptr: virtual sink token with fixed logit participates in softmax
+    - window_left + sink_size creates a real gap; gap tokens are correctly masked
+    """
+    run_batch_prefill_sink(
+        batch_size=batch_size,
+        qo_len=qo_len,
+        kv_len=kv_len,
+        page_size=page_size,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        window_left=window_left,
+        sink_size=sink_size,
+        sink_ptr_value=sink_ptr_value,
+        dtype=dtype,
+        seed=seed,
+    )
+
+
+# CI runs `python3 test_batch_prefill.py` (no pytest), so the __main__ block
+# above only executes the non-sink scenarios. Add a small representative sweep
+# of the StreamLLM sink scenarios here so they actually exercise in CI.
+if __name__ == "__main__":
+    sink_cases = list(
+        itertools.product(
+            [(128, 512), (1024, 2048)],  # (window_left, kv_len)
+            [4],  # sink_size
+            [None, 2.0],  # sink_ptr_value
+            [torch.bfloat16],  # dtype
+        )
+    )
+    for (window_left, kv_len), sink_size, sink_ptr_value, dtype in sink_cases:
+        run_batch_prefill_sink(
+            batch_size=1,
+            qo_len=128,
+            kv_len=kv_len,
+            page_size=16,
+            num_qo_heads=8,
+            num_kv_heads=1,
+            head_dim=128,
+            window_left=window_left,
+            sink_size=sink_size,
+            sink_ptr_value=sink_ptr_value,
+            dtype=dtype,
+            seed=42,
+        )