From a2da1445c36295e89b0b75c963d5e8ff4c5f726b Mon Sep 17 00:00:00 2001 From: Valerie Chen Date: Mon, 27 Oct 2025 16:19:24 -0700 Subject: [PATCH 1/2] Enable large batch size and optimization of non-Ragged batching --- .../ops/triton/_triton_kernels/lean_atten.py | 249 +++++++++++++----- aiter/ops/triton/lean_atten.py | 10 +- op_tests/op_benchmarks/triton/bench_la.py | 9 +- op_tests/triton_tests/test_la.py | 34 +-- 4 files changed, 211 insertions(+), 91 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/lean_atten.py b/aiter/ops/triton/_triton_kernels/lean_atten.py index 11b4a3759f..0ca0959169 100644 --- a/aiter/ops/triton/_triton_kernels/lean_atten.py +++ b/aiter/ops/triton/_triton_kernels/lean_atten.py @@ -261,6 +261,7 @@ def la_persistent( num_heads_k: tl.constexpr, gqa_group_size: tl.constexpr, use_64_indexing: tl.constexpr, + RAGGED_BATCH: tl.constexpr, ): if is_pod: current_pid = pod_pid @@ -356,6 +357,7 @@ def la_persistent( num_splits=num_splits, gqa_group_size=gqa_group_size, use_64_indexing=use_64_indexing, + RAGGED_BATCH=RAGGED_BATCH, ) @@ -410,6 +412,7 @@ def la_persistent_inner( num_splits: tl.constexpr, gqa_group_size: tl.constexpr, use_64_indexing: tl.constexpr, + RAGGED_BATCH: tl.constexpr, ): tl.assume(stride_qm > 0) # n_ctx_q @@ -468,23 +471,30 @@ def la_persistent_inner( tile_head_idx * batch_size + tile_batch_idx ) * num_m_blocks + per_head_tile_idx else: - tile_idx = ( - tile_head_idx * batch_size - ) # Output tile idx, 1 output tile per head per batch - tile_iter = tile_head_idx * tiles_per_head - if batch_size == 1: - req_size = tiles_per_head + if not RAGGED_BATCH: + group_size = tiles_per_head // batch_size + tile_batch_idx = (iter % tiles_per_head) // group_size + tile_idx = tile_head_idx * batch_size + tile_batch_idx + tile_iter = tile_head_idx * tiles_per_head + (tile_batch_idx * group_size) + tile_iter_end = tile_iter + group_size else: - req_size = tl.load(batch_num_block_n) - tile_iter_end = tile_iter + req_size - for b in range(1, batch_size): - next_req_size = tl.load(batch_num_block_n + b) - local_head_iter = iter % tiles_per_head - if (local_head_iter < next_req_size) and (local_head_iter >= req_size): - tile_iter = tile_iter + req_size - tile_idx = tile_idx + b - tile_iter_end = tile_iter + (next_req_size - req_size) - req_size = next_req_size + tile_idx = ( + tile_head_idx * batch_size + ) # Output tile idx, 1 output tile per head per batch + tile_iter = tile_head_idx * tiles_per_head + if batch_size == 1: + req_size = tiles_per_head + else: + req_size = tl.load(batch_num_block_n) + tile_iter_end = tile_iter + req_size + for b in range(1, batch_size): + next_req_size = tl.load(batch_num_block_n + b) + local_head_iter = iter % tiles_per_head + if (local_head_iter < next_req_size) and (local_head_iter >= req_size): + tile_iter = tile_iter + req_size + tile_idx = tile_idx + b + tile_iter_end = tile_iter + (next_req_size - req_size) + req_size = next_req_size # Local lean tile ID within a loop of an output tile local_iter = iter - tile_iter local_iter_end = tl.minimum(tile_iter_end, cta_end_tile_gid) - tile_iter @@ -510,9 +520,11 @@ def la_persistent_inner( offs_k = tl.arange(0, HEAD_DIM) mask_k_cols = offs_k < HEAD_DIM_ORIG - if causal: + if causal or not RAGGED_BATCH: + # Prefill or non RAGGED_BATCH b_seq_size = tile_batch_idx * num_n_blocks else: + # Decode with RAGGED_BATCH tile_batch_idx = tile_idx % batch_size b_seq_size = 0 if tile_batch_idx > 0: @@ -520,18 +532,40 @@ def la_persistent_inner( batch_num_block_n + tile_batch_idx - 1 ) # Previous batch size - k_offs = ( - (b_seq_size + local_iter) * BLOCK_N * stride_kn - + tile_khead_idx_global * stride_kh - + offs_n[None, :] * stride_kn - + offs_k[:, None] * stride_kk - ) - v_offs = ( - (b_seq_size + local_iter) * BLOCK_N * stride_vn - + tile_khead_idx_global * stride_vh - + offs_n[:, None] * stride_vn - + offs_k[None, :] * stride_vk - ) + if use_64_indexing: + BLOCK_N64 = tl.full((), BLOCK_N, tl.int64) + stride_kn64 = tl.full((), stride_kn, tl.int64) + stride_vn64 = tl.full((), stride_vn, tl.int64) + stride_kh64 = tl.full((), stride_kh, tl.int64) + stride_vh64 = tl.full((), stride_vh, tl.int64) + stride_kk64 = tl.full((), stride_kk, tl.int64) + stride_vk64 = tl.full((), stride_vk, tl.int64) + bn64 = tl.full((), b_seq_size, tl.int64) + tl.full((), local_iter, tl.int64) + k_offs = ( + (bn64 * BLOCK_N64) * stride_kn64 + + tl.full((), tile_khead_idx_global, tl.int64) * stride_kh64 + + offs_n[None, :] * stride_kn64 + + offs_k[:, None] * stride_kk64 + ) + v_offs = ( + (bn64 * BLOCK_N64) * stride_vn64 + + tl.full((), tile_khead_idx_global, tl.int64) * stride_vh64 + + offs_n[:, None] * stride_vn64 + + offs_k[None, :] * stride_vk64 + ) + else: + k_offs = ( + (b_seq_size + local_iter) * BLOCK_N * stride_kn + + tile_khead_idx_global * stride_kh + + offs_n[None, :] * stride_kn + + offs_k[:, None] * stride_kk + ) + v_offs = ( + (b_seq_size + local_iter) * BLOCK_N * stride_vn + + tile_khead_idx_global * stride_vh + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk + ) k_ptrs = K + k_offs k_ptrs = tl.multiple_of(k_ptrs, (16, 1)) @@ -545,12 +579,27 @@ def la_persistent_inner( q_idx = tile_batch_idx q_start_m = 0 - q_offs = ( - q_idx * BLOCK_M * stride_qm - + tile_head_idx_global * stride_qh - + offs_m[:, None] * stride_qm - + offs_k[None, :] * stride_qk - ) + if use_64_indexing: + q_idx64 = tl.full((), q_idx, tl.int64) + BLOCK_M64 = tl.full((), BLOCK_M, tl.int64) + stride_qm64 = tl.full((), stride_qm, tl.int64) + stride_qk64 = tl.full((), stride_qk, tl.int64) + th64 = tl.full((), tile_head_idx_global, tl.int64) * tl.full( + (), stride_qh, tl.int64 + ) + q_offs = ( + q_idx64 * BLOCK_M64 * stride_qm64 + + th64 + + offs_m[:, None] * stride_qm64 + + offs_k[None, :] * stride_qk64 + ) + else: + q_offs = ( + q_idx * BLOCK_M * stride_qm + + tile_head_idx_global * stride_qh + + offs_m[:, None] * stride_qm + + offs_k[None, :] * stride_qk + ) q_ptrs = Q + q_offs q_ptrs = tl.multiple_of(q_ptrs, (1, 16)) @@ -594,12 +643,27 @@ def la_persistent_inner( # Update pointers of partial results Mp[cta], Lp[cta], Op[cta] mp_ptrs = Mp + current_pid * BLOCK_M + offs_m lp_ptrs = Lp + current_pid * BLOCK_M + offs_m - op_ptrs = ( - Op - + current_pid * stride_oph # stride_oph is total_program dimension - + offs_m[:, None] * stride_opm - + offs_k[None, :] * stride_opn - ) + if use_64_indexing: + current_pid64 = tl.full((), current_pid, tl.int64) + BLOCK_M64 = tl.full((), BLOCK_M, tl.int64) + stride_oph64 = tl.full((), stride_oph, tl.int64) + stride_opm64 = tl.full((), stride_opm, tl.int64) + stride_opn64 = tl.full((), stride_opn, tl.int64) + offs_m64 = tl.full([BLOCK_M], 0, tl.int64) + tl.cast(offs_m, tl.int64) + offs_k64 = tl.full([HEAD_DIM], 0, tl.int64) + tl.cast(offs_k, tl.int64) + op_ptrs = ( + Op + + current_pid64 * stride_oph64 + + offs_m64[:, None] * stride_opm64 + + offs_k64[None, :] * stride_opn64 + ) + else: + op_ptrs = ( + Op + + current_pid * stride_oph # stride_oph is total_program dimension + + offs_m[:, None] * stride_opm + + offs_k[None, :] * stride_opn + ) tl.store(mp_ptrs, m_i, cache_modifier=".wt") tl.store(lp_ptrs, l_i, cache_modifier=".wt") @@ -705,19 +769,41 @@ def la_persistent_inner( offs_mplp = temp_pid * BLOCK_M + offs_m mp_ptrs = Mp + offs_mplp lp_ptrs = Lp + offs_mplp - op_ptrs0 = ( - Op - + temp_pid * stride_oph - + offs_m[:, None] * stride_opm - + tl.arange(0, HEAD_DIM // 2)[None, :] * stride_opn - ) - op_ptrs1 = ( - Op - + temp_pid * stride_oph - + offs_m[:, None] * stride_opm - + (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2) - * stride_opn - ) + if use_64_indexing: + temp_pid64 = tl.full((), temp_pid, tl.int64) + stride_oph64 = tl.full((), stride_oph, tl.int64) + stride_opm64 = tl.full((), stride_opm, tl.int64) + stride_opn64 = tl.full((), stride_opn, tl.int64) + offs_m64 = tl.cast(offs_m, tl.int64) + offs0 = tl.arange(0, HEAD_DIM // 2) + offs0_64 = tl.cast(offs0, tl.int64) + offs1_64 = offs0_64 + tl.full((), HEAD_DIM // 2, tl.int64) + op_ptrs0 = ( + Op + + temp_pid64 * stride_oph64 + + offs_m64[:, None] * stride_opm64 + + offs0_64[None, :] * stride_opn64 + ) + op_ptrs1 = ( + Op + + temp_pid64 * stride_oph64 + + offs_m64[:, None] * stride_opm64 + + offs1_64[None, :] * stride_opn64 + ) + else: + op_ptrs0 = ( + Op + + temp_pid * stride_oph + + offs_m[:, None] * stride_opm + + tl.arange(0, HEAD_DIM // 2)[None, :] * stride_opn + ) + op_ptrs1 = ( + Op + + temp_pid * stride_oph + + offs_m[:, None] * stride_opm + + (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2) + * stride_opn + ) m_cta = tl.load(mp_ptrs, cache_modifier=".cv") l_cta = tl.load(lp_ptrs, cache_modifier=".cv") @@ -744,20 +830,47 @@ def la_persistent_inner( # host CTA write final result to memory # acc = acc / l_i[:, None] # tl.store(o_ptrs, acc.to(Out.type.element_ty)) - o_ptrs0 = ( - Out - + q_idx * BLOCK_M * stride_om - + tile_head_idx_global * stride_oh - + offs_m[:, None] * stride_om - + tl.arange(0, HEAD_DIM // 2)[None, :] * stride_on - ) - o_ptrs1 = ( - Out - + q_idx * BLOCK_M * stride_om - + tile_head_idx_global * stride_oh - + offs_m[:, None] * stride_om - + (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2) * stride_on - ) + if use_64_indexing: + q_idx64 = tl.full((), q_idx, tl.int64) + BLOCK_M64 = tl.full((), BLOCK_M, tl.int64) + stride_om64 = tl.full((), stride_om, tl.int64) + stride_on64 = tl.full((), stride_on, tl.int64) + th64 = tl.full((), tile_head_idx_global, tl.int64) * tl.full( + (), stride_oh, tl.int64 + ) + offs0 = tl.arange(0, HEAD_DIM // 2) + offs0_64 = tl.cast(offs0, tl.int64) + offs1_64 = offs0_64 + tl.full((), HEAD_DIM // 2, tl.int64) + + o_ptrs0 = ( + Out + + q_idx64 * BLOCK_M64 * stride_om64 + + th64 + + offs_m[:, None] * stride_om64 + + offs0_64[None, :] * stride_on64 + ) + o_ptrs1 = ( + Out + + q_idx64 * BLOCK_M64 * stride_om64 + + th64 + + offs_m[:, None] * stride_om64 + + offs1_64[None, :] * stride_on64 + ) + else: + o_ptrs0 = ( + Out + + q_idx * BLOCK_M * stride_om + + tile_head_idx_global * stride_oh + + offs_m[:, None] * stride_om + + tl.arange(0, HEAD_DIM // 2)[None, :] * stride_on + ) + o_ptrs1 = ( + Out + + q_idx * BLOCK_M * stride_om + + tile_head_idx_global * stride_oh + + offs_m[:, None] * stride_om + + (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2) * stride_on + ) acc0 = acc0 / l_i[:, None] acc1 = acc1 / l_i[:, None] diff --git a/aiter/ops/triton/lean_atten.py b/aiter/ops/triton/lean_atten.py index bc1ecd21a4..f5ed5d8d77 100644 --- a/aiter/ops/triton/lean_atten.py +++ b/aiter/ops/triton/lean_atten.py @@ -20,6 +20,7 @@ import torch from typing import Optional from bisect import bisect_right +import math import triton import triton.language as tl from aiter.ops.triton._triton_kernels.lean_atten import la_persistent, _get_config @@ -187,6 +188,9 @@ def _persistent_lean_attention( MASKED_BLOCKS=MASKED_BLOCKS, MODE=CAUSAL_MODE, ) + if not causal: + max_output_tile_cnt = math.ceil((H * batch_size) / total_programs) + if DEBUG: print(f"max_output_tile_cnt={max_output_tile_cnt}") @@ -243,8 +247,6 @@ def _persistent_lean_attention( f"locks must have length >= total_programs ({total_programs}), got {locks.numel()}" ) - max_output_tile_cnt = max_output_tile_cnt + 4 - grid = (total_programs, 1, 1) o = torch.empty_like(q, dtype=v.dtype) @@ -321,7 +323,9 @@ def _persistent_lean_attention( or (Op.stride(0) * total_programs) >= (1 << 31) or (Op.stride(1) * N_CTX_Q) >= (1 << 31) or (o.stride(0) * N_CTX_Q) >= (1 << 31) + or (q.stride(0) * N_CTX_Q) >= (1 << 31) ), + RAGGED_BATCH=False, **config, ) """ @@ -332,7 +336,7 @@ def _persistent_lean_attention( kernel_timing[k]["ms"] += ms total_ms = kernel_timing["attn_fwd"]["ms"] """ - # print(f"la kernel {la_kernel.n_regs} registers used, {la_kernel.n_spills} spills") + print(f"la kernel {la_kernel.n_regs} registers used, {la_kernel.n_spills} spills") ms = 0 return (o, ms) diff --git a/op_tests/op_benchmarks/triton/bench_la.py b/op_tests/op_benchmarks/triton/bench_la.py index 7902f8e3b4..d138141def 100644 --- a/op_tests/op_benchmarks/triton/bench_la.py +++ b/op_tests/op_benchmarks/triton/bench_la.py @@ -19,7 +19,7 @@ "hq", "hk", "n_ctx_q", - "n_ctx", + "n_ctx_k", "d", "total_programs", "init_dtype", @@ -100,6 +100,9 @@ (True, 1, 32, 32, 2048, [2048], 128, 608, torch.float16, 128, 64, 2, 4), (True, 1, 64, 32, 2048, [2048], 128, 608, torch.float16, 128, 64, 2, 4), (True, 1, 128, 32, 2048, [2048], 128, 608, torch.float16, 128, 64, 2, 4), + (False, 512, 32, 8, 16, [8192], 128, 608, torch.float16, 16, 64, 2, 4), + (False, 512, 64, 8, 16, [8192], 128, 608, torch.float16, 16, 128, 2, 4), + (False, 512, 128, 8, 16, [8192], 128, 608, torch.float16, 16, 128, 2, 4), ], line_arg="provider", line_vals=["triton"], @@ -121,7 +124,7 @@ def bench_lean_attention( hq, hk, n_ctx_q, - n_ctx, + n_ctx_k, d, total_programs, init_dtype, @@ -132,7 +135,7 @@ def bench_lean_attention( provider, device="cuda", ): - + n_ctx = n_ctx_k * batch assert batch == len(n_ctx) try: diff --git a/op_tests/triton_tests/test_la.py b/op_tests/triton_tests/test_la.py index 98368b7244..504a2a8fdf 100644 --- a/op_tests/triton_tests/test_la.py +++ b/op_tests/triton_tests/test_la.py @@ -369,16 +369,16 @@ def print_mismatches(ref_out, la_out, atol=1e-8, rtol=1e-5): def main(): # (True, 2, 64, 8, 16384, [16384, 16384], 128, 608, torch.float16, 128, 64, 2, 4), - batch = 1 + batch = 1024 causal = False - hq = 128 - hk = 128 - n_ctx_q = 8192 - n_ctx = [8192] * 1 # [16384] #[8192] + hq = 32 + hk = 8 + n_ctx_q = 16 + n_ctx = [8192] * batch # [16384] #[8192] d = 128 total_programs = 304 init_dtype = torch.float16 - BLOCK_M = 128 + BLOCK_M = 16 BLOCK_N = 64 XCD_REMAP = True waves_per_eu = 2 @@ -441,19 +441,19 @@ def main(): ) # print(f"ms={ms}") - # ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal) + ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal) # # Compare result - # atol = 1.4e-1 if init_dtype == "fp8" else 1e-2 - # rtol = 1e-2 if init_dtype == "fp8" else 3e-3 - # try: - # torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) - # except AssertionError: - # print("Assertion failed! Showing mismatches:") - # # print_mismatches(ref_out, la_out, atol, rtol) - # raise # Re-raise the exception after printing mismatches - - # # torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) + atol = 1.4e-1 if init_dtype == "fp8" else 1e-2 + rtol = 1e-2 if init_dtype == "fp8" else 3e-3 + try: + torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) + except AssertionError: + # print("Assertion failed! Showing mismatches:") + # # print_mismatches(ref_out, la_out, atol, rtol) + raise # Re-raise the exception after printing mismatches + + # torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) if __name__ == "__main__": From 0df65e1da431663d49416de48e4d953c5a6198be Mon Sep 17 00:00:00 2001 From: Valerie Chen Date: Tue, 28 Oct 2025 11:17:04 -0700 Subject: [PATCH 2/2] Add RAGGED_BATCH to test_la.py and bench_la.py --- aiter/ops/triton/lean_atten.py | 9 +- op_tests/op_benchmarks/triton/bench_la.py | 207 +++++++++++++++++++-- op_tests/triton_tests/test_la.py | 211 +++++++++++++++++++--- 3 files changed, 386 insertions(+), 41 deletions(-) diff --git a/aiter/ops/triton/lean_atten.py b/aiter/ops/triton/lean_atten.py index f5ed5d8d77..aca61ab770 100644 --- a/aiter/ops/triton/lean_atten.py +++ b/aiter/ops/triton/lean_atten.py @@ -46,6 +46,7 @@ def persistent_lean_attention( batch_size: int, sm_scale: torch.float16, causal: bool = True, # causal masking + RAGGED_BATCH: bool = False, config: Optional[dict] = None, program_count: Optional[int] = None, ): @@ -80,6 +81,7 @@ def persistent_lean_attention( causal=causal, batch_size=batch_size, sm_scale=sm_scale, + RAGGED_BATCH=RAGGED_BATCH, num_warps=config["num_warps"], waves_per_eu=config["waves_per_eu"], config=config, @@ -103,6 +105,7 @@ def _persistent_lean_attention( causal: bool, # causal masking batch_size: int, sm_scale: torch.float16, # typically 1 / sqrt(d) + RAGGED_BATCH: bool, num_warps: int, waves_per_eu: int, config: dict = {}, @@ -189,7 +192,7 @@ def _persistent_lean_attention( MODE=CAUSAL_MODE, ) if not causal: - max_output_tile_cnt = math.ceil((H * batch_size) / total_programs) + max_output_tile_cnt = math.ceil((H * batch_size) / total_programs) + 4 if DEBUG: print(f"max_output_tile_cnt={max_output_tile_cnt}") @@ -325,7 +328,7 @@ def _persistent_lean_attention( or (o.stride(0) * N_CTX_Q) >= (1 << 31) or (q.stride(0) * N_CTX_Q) >= (1 << 31) ), - RAGGED_BATCH=False, + RAGGED_BATCH=RAGGED_BATCH, **config, ) """ @@ -336,7 +339,7 @@ def _persistent_lean_attention( kernel_timing[k]["ms"] += ms total_ms = kernel_timing["attn_fwd"]["ms"] """ - print(f"la kernel {la_kernel.n_regs} registers used, {la_kernel.n_spills} spills") + # print(f"la kernel {la_kernel.n_regs} registers used, {la_kernel.n_spills} spills") ms = 0 return (o, ms) diff --git a/op_tests/op_benchmarks/triton/bench_la.py b/op_tests/op_benchmarks/triton/bench_la.py index d138141def..457f2c2b5b 100644 --- a/op_tests/op_benchmarks/triton/bench_la.py +++ b/op_tests/op_benchmarks/triton/bench_la.py @@ -25,6 +25,7 @@ "init_dtype", "BLOCK_M", "BLOCK_N", + "RAGGED_BATCH", "waves_per_eu", "num_warps", ], @@ -91,18 +92,198 @@ # ), # Causal=1, # (True, 2, 64, 64, 2048, [2048, 2048], 128, 608, torch.float16, 128, 64, 2, 4), # Diff here - (True, 1, 32, 8, 8192, [8192], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 64, 8, 8192, [8192], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 128, 8, 8192, [8192], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 32, 16, 1024, [1024], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 64, 16, 1024, [1024], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 128, 16, 1024, [1024], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 32, 32, 2048, [2048], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 64, 32, 2048, [2048], 128, 608, torch.float16, 128, 64, 2, 4), - (True, 1, 128, 32, 2048, [2048], 128, 608, torch.float16, 128, 64, 2, 4), - (False, 512, 32, 8, 16, [8192], 128, 608, torch.float16, 16, 64, 2, 4), - (False, 512, 64, 8, 16, [8192], 128, 608, torch.float16, 16, 128, 2, 4), - (False, 512, 128, 8, 16, [8192], 128, 608, torch.float16, 16, 128, 2, 4), + ( + True, + 1, + 32, + 8, + 8192, + [8192], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 64, + 8, + 8192, + [8192], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 128, + 8, + 8192, + [8192], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 32, + 16, + 1024, + [1024], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 64, + 16, + 1024, + [1024], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 128, + 16, + 1024, + [1024], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 32, + 32, + 2048, + [2048], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 64, + 32, + 2048, + [2048], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + True, + 1, + 128, + 32, + 2048, + [2048], + 128, + 608, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), + ( + False, + 512, + 32, + 8, + 16, + [8192], + 128, + 608, + torch.float16, + 16, + 64, + False, + 2, + 4, + ), + ( + False, + 512, + 64, + 8, + 16, + [8192], + 128, + 608, + torch.float16, + 16, + 128, + False, + 2, + 4, + ), + ( + False, + 512, + 128, + 8, + 16, + [8192], + 128, + 608, + torch.float16, + 16, + 128, + False, + 2, + 4, + ), ], line_arg="provider", line_vals=["triton"], @@ -130,6 +311,7 @@ def bench_lean_attention( init_dtype, BLOCK_M, BLOCK_N, + RAGGED_BATCH, waves_per_eu, num_warps, provider, @@ -196,6 +378,7 @@ def bench_lean_attention( causal, batch, sm_scale, + RAGGED_BATCH, num_warps, waves_per_eu, ) diff --git a/op_tests/triton_tests/test_la.py b/op_tests/triton_tests/test_la.py index 504a2a8fdf..b4469e56a3 100644 --- a/op_tests/triton_tests/test_la.py +++ b/op_tests/triton_tests/test_la.py @@ -101,24 +101,159 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): @pytest.mark.parametrize( - "causal, batch, hq, hk, n_ctx_q, n_ctx, d, total_programs, init_dtype, BLOCK_M, BLOCK_N, waves_per_eu, num_warps ", + "causal, batch, hq, hk, n_ctx_q, n_ctx, d, total_programs, init_dtype, BLOCK_M, BLOCK_N, RAGGED_BATCH, waves_per_eu, num_warps ", [ - (False, 2, 64, 64, 128, [65536, 65536], 128, 304, torch.float16, 128, 64, 1, 4), - (False, 2, 64, 64, 16, [65536, 65536], 128, 912, torch.float16, 16, 128, 3, 4), - (False, 1, 64, 64, 16, [131072], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 64, 64, 16, [262144], 64, 912, torch.float16, 16, 64, 2, 4), - (False, 1, 64, 64, 16, [524288], 64, 912, torch.float16, 16, 64, 2, 4), - (False, 2, 96, 96, 16, [32768, 32768], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 96, 96, 16, [65536], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 96, 96, 16, [131072], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 96, 96, 16, [262144], 64, 912, torch.float16, 16, 64, 2, 4), - (False, 1, 96, 96, 16, [524288], 16, 912, torch.float16, 16, 256, 1, 4), # - (False, 1, 96, 96, 16, [1048576], 16, 912, torch.float16, 16, 256, 1, 4), # - (False, 1, 128, 128, 16, [32768], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 128, 128, 16, [65536], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 128, 128, 16, [131072], 128, 912, torch.float16, 16, 128, 2, 4), - (False, 1, 128, 128, 16, [262144], 64, 912, torch.float16, 16, 64, 2, 4), - (False, 1, 128, 128, 16, [524288], 16, 912, torch.float16, 16, 256, 1, 4), # + ( + False, + 2, + 64, + 64, + 128, + [65536, 65536], + 128, + 304, + torch.float16, + 128, + 64, + False, + 1, + 4, + ), + ( + False, + 2, + 64, + 64, + 16, + [65536, 65536], + 128, + 912, + torch.float16, + 16, + 128, + False, + 3, + 4, + ), + (False, 1, 64, 64, 16, [131072], 128, 912, torch.float16, 16, 128, False, 2, 4), + (False, 1, 64, 64, 16, [262144], 64, 912, torch.float16, 16, 64, False, 2, 4), + (False, 1, 64, 64, 16, [524288], 64, 912, torch.float16, 16, 64, False, 2, 4), + ( + False, + 2, + 96, + 96, + 16, + [32768, 32768], + 128, + 912, + torch.float16, + 16, + 128, + False, + 2, + 4, + ), + (False, 1, 96, 96, 16, [65536], 128, 912, torch.float16, 16, 128, False, 2, 4), + (False, 1, 96, 96, 16, [131072], 128, 912, torch.float16, 16, 128, False, 2, 4), + (False, 1, 96, 96, 16, [262144], 64, 912, torch.float16, 16, 64, False, 2, 4), + ( + False, + 1, + 96, + 96, + 16, + [524288], + 16, + 912, + torch.float16, + 16, + 256, + False, + 1, + 4, + ), # + ( + False, + 1, + 96, + 96, + 16, + [1048576], + 16, + 912, + torch.float16, + 16, + 256, + False, + 1, + 4, + ), # + ( + False, + 1, + 128, + 128, + 16, + [32768], + 128, + 912, + torch.float16, + 16, + 128, + False, + 2, + 4, + ), + ( + False, + 1, + 128, + 128, + 16, + [65536], + 128, + 912, + torch.float16, + 16, + 128, + False, + 2, + 4, + ), + ( + False, + 1, + 128, + 128, + 16, + [131072], + 128, + 912, + torch.float16, + 16, + 128, + False, + 2, + 4, + ), + (False, 1, 128, 128, 16, [262144], 64, 912, torch.float16, 16, 64, False, 2, 4), + ( + False, + 1, + 128, + 128, + 16, + [524288], + 16, + 912, + torch.float16, + 16, + 256, + False, + 1, + 4, + ), # ( False, 3, @@ -131,6 +266,7 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): torch.float16, 16, 128, + True, 2, 4, ), @@ -146,6 +282,7 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): torch.float16, 16, 64, + True, 2, 4, ), @@ -161,10 +298,26 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): torch.float16, 128, 64, + False, 2, 4, ), # Causal=1, - (True, 2, 64, 64, 2048, [2048, 2048], 128, 304, torch.float16, 128, 64, 2, 4), + ( + True, + 2, + 64, + 64, + 2048, + [2048, 2048], + 128, + 304, + torch.float16, + 128, + 64, + False, + 2, + 4, + ), # These test cases fail: # (True, 2, 64, 2048, [2048, 2048], 128, 304, torch.float16, 128, 64, 2, 4), # (True, 1, 64, 4096, [4096], 128, 304, torch.float16, 128, 16, 3, 4), @@ -173,6 +326,7 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): ) def test_persistent_lean_attention( request, + causal, batch, hq, hk, @@ -183,9 +337,9 @@ def test_persistent_lean_attention( init_dtype, BLOCK_M, BLOCK_N, + RAGGED_BATCH, waves_per_eu, num_warps, - causal, ): torch.cuda.empty_cache() # Helps avoid hangs in large tests @@ -251,6 +405,7 @@ def test_persistent_lean_attention( causal, batch, sm_scale, + RAGGED_BATCH, num_warps, waves_per_eu, ) @@ -276,6 +431,7 @@ def test_persistent_lean_attention( @pytest.mark.parametrize("d", [32]) @pytest.mark.parametrize("causal", [(True), (False)]) @pytest.mark.parametrize("init_dtype", [torch.float16]) +@pytest.mark.parametrize("RAGGED_BATCH", [False]) def test_persistent_lean_attention_outer( batch, h, @@ -284,6 +440,7 @@ def test_persistent_lean_attention_outer( d, init_dtype, causal, + RAGGED_BATCH, ): torch.manual_seed(20) @@ -325,6 +482,7 @@ def test_persistent_lean_attention_outer( batch, sm_scale, causal=causal, + RAGGED_BATCH=RAGGED_BATCH, config=config, ) @@ -368,21 +526,21 @@ def print_mismatches(ref_out, la_out, atol=1e-8, rtol=1e-5): def main(): - # (True, 2, 64, 8, 16384, [16384, 16384], 128, 608, torch.float16, 128, 64, 2, 4), - batch = 1024 + batch = 3 causal = False - hq = 32 - hk = 8 + hq = 128 + hk = 128 n_ctx_q = 16 - n_ctx = [8192] * batch # [16384] #[8192] + n_ctx = [4096, 32768, 65536] # [131072] * batch # [16384] #[8192] d = 128 - total_programs = 304 + total_programs = 912 init_dtype = torch.float16 BLOCK_M = 16 - BLOCK_N = 64 + BLOCK_N = 128 XCD_REMAP = True waves_per_eu = 2 num_warps = 4 + RAGGED_BATCH = True assert batch == len(n_ctx) try: @@ -436,6 +594,7 @@ def main(): causal, batch, sm_scale, + RAGGED_BATCH, num_warps, waves_per_eu, )