diff --git a/third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py b/third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py index 9c51113fd1..7f1d673212 100644 --- a/third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py +++ b/third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py @@ -4,8 +4,8 @@ import triton import triton.language as tl import triton.language.extra.tlx as tlx -from triton.tools.tensor_descriptor import TensorDescriptor from triton._internal_testing import is_blackwell +from triton.tools.tensor_descriptor import TensorDescriptor DEVICE = triton.runtime.driver.active.get_active_torch_device() @@ -31,9 +31,17 @@ def _host_descriptor_pre_hook(nargs): configs = [ triton.Config( { - 'BLOCK_M': 256, 'BLOCK_N': 128, 'NUM_BUFFERS_Q': 1, 'NUM_BUFFERS_KV': 3, 'NUM_BUFFERS_QK': 1, - 'NUM_MMA_GROUPS': 2 - }, num_stages=0, num_warps=4, pre_hook=_host_descriptor_pre_hook), + "BLOCK_M": 256, + "BLOCK_N": 128, + "NUM_BUFFERS_Q": 1, + "NUM_BUFFERS_KV": 3, + "NUM_BUFFERS_QK": 1, + "NUM_MMA_GROUPS": 2, + }, + num_stages=0, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), ] @@ -45,19 +53,133 @@ def _get_bufidx_phase(accum_cnt, NUM_BUFFERS_KV): @triton.jit -def _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M): +def _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr): + if STAGE == 1: + # First part of STAGE == 3 in _get_fused_loop_bounds + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Second part of STAGE == 3 in _get_fused_loop_bounds + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + else: + tl.static_assert(STAGE == 3) + # Maps to STAGE=1 in _get_fused_loop_bounds + lo, hi = 0, N_CTX + return lo, hi + + +@triton.jit +def _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr): + if STAGE == 1: + return 0, N_CTX + else: + tl.static_assert(STAGE == 3) + return 0, (start_m + 1) * BLOCK_M + + +@triton.jit +def _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE: tl.constexpr): start_m = tile_idx % n_tile_num off_hz = tile_idx // n_tile_num off_z = off_hz // H off_h = off_hz % H offset_y = off_z * (N_CTX * H) + off_h * N_CTX qo_offset_y = offset_y + start_m * BLOCK_M - lo, hi = 0, N_CTX + lo, hi = _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) kv_offset_y = offset_y + lo return start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y -@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"]) +@triton.jit +def _mask_scalar(qk, col_limit_right, s, i): + col_lim_right_s = col_limit_right - s + col_lim_right_cur = max(col_lim_right_s, 0) + mask = -1 << col_lim_right_cur + mask_i_bit = (mask & (1 << i)) == 0 + return tl.where(mask_i_bit, qk, -float("inf")) + + +@triton.jit +def _apply_causal_mask(qk, col_limit_right, HEAD_DIM: tl.constexpr): + # Apply causal mask via a bitmask calculated for each block of 16 elements. + # This allows the efficient R2P (register to predicate) instruction to be used at the SASS level. + # Credit to Tri Dao, + # https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78 + # + # NOTE: We use map_elementiwse here in order to generate an interleaved sequence of instructions + # that processes one element of qk at a time. This improves ptxas's resulting SASS. + offs_n = tl.arange(0, HEAD_DIM)[None, :] + s = offs_n & ~0xF + i = offs_n & 0xF + return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i) + + +@triton.jit +def _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_BUFFERS_QK: tl.constexpr, + NUM_MMA_GROUPS: tl.constexpr, + STAGE: tl.constexpr, +): + lo, hi = _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) + + for start_n in tl.range(lo, hi, BLOCK_N): + qk_bufIdx, qk_phase = _get_bufidx_phase(accum_cnt_qk, NUM_BUFFERS_QK) + qk_bufIdx += cid * NUM_BUFFERS_QK + + tlx.barrier_wait(tlx.local_view(qk_fulls, qk_bufIdx), qk_phase) + qk = tlx.local_load(tlx.local_view(qk_tiles, qk_bufIdx)) + + if STAGE == 2: + col_limit_right = (offs_m - start_n + 1)[:, None] + qk = _apply_causal_mask(qk, col_limit_right, HEAD_DIM) + + # compute m_i, p in registers + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + + # -- compute correction factor + alpha = tl.math.exp2(m_i - m_ij) + tlx.barrier_wait(tlx.local_view(alpha_empties, qk_bufIdx), qk_phase ^ 1) + # Use alpha[0] for cid=0, and alpha[HEAD_DIM * NUM_BUFFERS_QK] for cid=1 + tlx.local_store(tlx.local_view(alpha_tiles, cid * HEAD_DIM * NUM_BUFFERS_QK), alpha[:, None]) + tlx.barrier_arrive(tlx.local_view(alpha_fulls, qk_bufIdx)) + + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + p = p.to(out_dtype) + + # prepare p for the v dot + # Use p[1] for cid=0, and p[3] for cid=1 + p_bufIdx = 1 + cid * NUM_MMA_GROUPS * NUM_BUFFERS_QK + tlx.local_store(tlx.local_view(p_tiles, p_bufIdx), p) + tlx.barrier_arrive(tlx.local_view(p_fulls, qk_bufIdx)) + + l_i = l_i * alpha + l_ij + m_i = m_ij + accum_cnt_qk += 1 + + return m_i, l_i, accum_cnt_qk + + +@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "STAGE"]) @triton.jit def _attn_fwd_ws(sm_scale, M, # Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # @@ -65,6 +187,7 @@ def _attn_fwd_ws(sm_scale, M, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # FP8_OUTPUT: tl.constexpr, # + STAGE: tl.constexpr, # NUM_BUFFERS_Q: tl.constexpr, # NUM_BUFFERS_KV: tl.constexpr, # NUM_BUFFERS_QK: tl.constexpr, # @@ -103,14 +226,34 @@ def _attn_fwd_ws(sm_scale, M, # qk_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tl.float32, NUM_MMA_GROUPS, tlx.storage_kind.tmem) # Shared buffer for QK, P and Alpha, l, and m. # Alpha/l/m lives in the lower half of qk_buf, and P lives in the upper half. - p_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tlx.dtype_of(desc_v), NUM_MMA_GROUPS * 2, - tlx.storage_kind.tmem, reuse=qk_tiles) - alpha_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - l_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - m_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) + p_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, HEAD_DIM), + tlx.dtype_of(desc_v), + NUM_MMA_GROUPS * 2, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + alpha_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + l_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + m_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) acc_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tl.float32, NUM_MMA_GROUPS, tlx.storage_kind.tmem) @@ -132,7 +275,7 @@ def _attn_fwd_ws(sm_scale, M, # for i in range(0, tiles_per_sm): # initialize offsets start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets( - tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) for _ in tl.range(lo, hi, BLOCK_N): _, phase = _get_bufidx_phase(accum_cnt, 1) for cid in tl.range(0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS): @@ -157,7 +300,7 @@ def _attn_fwd_ws(sm_scale, M, # tlx.barrier_arrive(qk_empties[cid]) m = tlx.local_load(m_tiles[cid * HEAD_DIM + 2]) m += tl.math.log2(l) - offs_m = start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT) + offs_m = (start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)) m_ptrs = M + off_hz * N_CTX + offs_m tl.store(m_ptrs, tl.reshape(m, [BLOCK_M_SPLIT])) @@ -175,44 +318,68 @@ def _attn_fwd_ws(sm_scale, M, # for i in range(0, tiles_per_sm): # initialize offsets start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets( - tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M_SPLIT, HEAD_DIM], dtype=tl.float32) qk_scale = sm_scale qk_scale *= 1.44269504 # 1/log(2) + out_dtype = tlx.dtype_of(desc_v) cid = tlx.async_task_replica_id() - for _ in tl.range(lo, hi, BLOCK_N): - _, qk_phase = _get_bufidx_phase(accum_cnt_qk, 1) - tlx.barrier_wait(qk_fulls[cid], qk_phase) - qk = tlx.local_load(qk_tiles[cid]) - - # compute m_i, p in registers - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - - # -- compute correction factor - alpha = tl.math.exp2(m_i - m_ij) - tlx.barrier_wait(alpha_empties[cid], qk_phase ^ 1) - # Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1 - tlx.local_store(alpha_tiles[cid * HEAD_DIM], alpha[:, None]) - tlx.barrier_arrive(alpha_fulls[cid]) - - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - p = p.to(tlx.dtype_of(desc_v)) - - # prepare p for the v dot - # Use p[1] for cid=0, and p[3] for cid=1 - p_bufIdx = 1 + cid * NUM_MMA_GROUPS - tlx.local_store(p_tiles[p_bufIdx], p) - tlx.barrier_arrive(p_fulls[cid]) + offs_m = start_m * BLOCK_M + ((cid * BLOCK_M_SPLIT) + tl.arange(0, BLOCK_M_SPLIT)) + if STAGE & 1: + m_i, l_i, accum_cnt_qk = _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M, + BLOCK_N, + HEAD_DIM, + NUM_BUFFERS_QK, + NUM_MMA_GROUPS, + STAGE=4 - STAGE, + ) - l_i = l_i * alpha + l_ij - m_i = m_ij - accum_cnt_qk += 1 + if STAGE & 2: + m_i, l_i, accum_cnt_qk = _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M, + BLOCK_N, + HEAD_DIM, + NUM_BUFFERS_QK, + NUM_MMA_GROUPS, + STAGE=2, + ) # prepare l_i for the epilog # Use l[1]/l[1+HEAD_DIM] and m[2][2 + HEAD_DIM] @@ -229,7 +396,7 @@ def _attn_fwd_ws(sm_scale, M, # for j in range(0, tiles_per_sm): # initialize offsets - _, _, lo, hi, _, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + _, _, lo, hi, _, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) # wait for the Q buffer to be populated by the producer q_bufIdx, q_phase = _get_bufidx_phase(j, NUM_BUFFERS_Q) @@ -303,7 +470,8 @@ def _attn_fwd_ws(sm_scale, M, # accum_cnt_kv = 0 for i in range(0, tiles_per_sm): # initialize offsets - _, _, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + _, _, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, + STAGE) # load q: it will stay in SRAM throughout q_bufIdx, q_phase = _get_bufidx_phase(i, NUM_BUFFERS_Q) @@ -349,13 +517,16 @@ def _attn_fwd_ws(sm_scale, M, # class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, sm_scale): + def forward(ctx, q, k, v, sm_scale, causal): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. HEAD_DIM_V = v.shape[-1] assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + stage = 3 if causal else 1 + o = torch.empty_like(q) extra_kern_args = {} @@ -364,13 +535,38 @@ def forward(ctx, q, k, v, sm_scale): y_dim = q.shape[0] * q.shape[1] * q.shape[2] dummy_block = [1, 1] - desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_q = TensorDescriptor( + q, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) if q.dtype == torch.float8_e5m2: - desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], block_shape=dummy_block) + desc_v = TensorDescriptor( + v, + shape=[HEAD_DIM_K, y_dim], + strides=[q.shape[2], 1], + block_shape=dummy_block, + ) else: - desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_v = TensorDescriptor( + v, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) + desc_o = TensorDescriptor( + o, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) def alloc_fn(size: int, align: int, _): return torch.empty(size, dtype=torch.int8, device="cuda") @@ -402,6 +598,7 @@ def grid(META): N_CTX=q.shape[2], # HEAD_DIM=HEAD_DIM_K, # FP8_OUTPUT=q.dtype == torch.float8_e5m2, # + STAGE=stage, # **extra_kern_args, ) @@ -424,7 +621,8 @@ def grid(META): @pytest.mark.parametrize("HEAD_DIM", [128]) @pytest.mark.parametrize("mode", ["fwd"]) @pytest.mark.parametrize("provider", ["triton-fp16"]) -def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): +@pytest.mark.parametrize("causal", [True, False]) +def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16): if mode == "bwd": pytest.skip("Backward pass not supported.") torch.manual_seed(20) @@ -439,11 +637,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): q = q.to(ref_dtype) k = k.to(ref_dtype) v = v.to(ref_dtype) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - p = torch.softmax(p.float(), dim=-1) - p = p.to(ref_dtype) - # p = torch.exp(p) - ref_out = torch.matmul(p, v).half() + ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal) # triton implementation if mode == "fwd" and "fp8" in provider: q = q.to(torch.float8_e5m2) @@ -451,7 +645,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): v = v.permute(0, 1, 3, 2).contiguous() v = v.permute(0, 1, 3, 2) v = v.to(torch.float8_e5m2) - tri_out = attention(q, k, v, sm_scale).half() + tri_out = attention(q, k, v, sm_scale, causal).half() if mode == "fwd": atol = 3 if "fp8" in provider else 1e-2 torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) @@ -459,8 +653,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): try: - from flash_attn.flash_attn_interface import \ - flash_attn_qkvpacked_func as flash_attn_func + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as flash_attn_func, ) + HAS_FLASH = True except BaseException: HAS_FLASH = False @@ -512,7 +707,12 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, mode, provider, device=DEVI ms = triton.testing.do_bench(fn) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn( + (BATCH, N_CTX, 3, H, HEAD_DIM), + dtype=dtype, + device=device, + requires_grad=True, + ) fn = lambda: flash_attn_func(qkv) if mode == "bwd": o = fn() diff --git a/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py b/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py index 2390915eb5..93636f44b0 100644 --- a/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py +++ b/third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py @@ -4,8 +4,8 @@ import triton import triton.language as tl import triton.language.extra.tlx as tlx -from triton.tools.tensor_descriptor import TensorDescriptor from triton._internal_testing import is_blackwell +from triton.tools.tensor_descriptor import TensorDescriptor DEVICE = triton.runtime.driver.active.get_active_torch_device() @@ -95,14 +95,38 @@ def _fma_f32x2(a, b, c): @triton.jit -def _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M): +def _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr): + if STAGE == 1: + # First part of STAGE == 3 in _get_fused_loop_bounds + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Second part of STAGE == 3 in _get_fused_loop_bounds + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + else: + tl.static_assert(STAGE == 3) + # Maps to STAGE=1 in _get_fused_loop_bounds + lo, hi = 0, N_CTX + return lo, hi + + +@triton.jit +def _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr): + if STAGE == 1: + return 0, N_CTX + else: + tl.static_assert(STAGE == 3) + return 0, (start_m + 1) * BLOCK_M + + +@triton.jit +def _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE: tl.constexpr): start_m = tile_idx % n_tile_num off_hz = tile_idx // n_tile_num off_z = off_hz // H off_h = off_hz % H offset_y = off_z * (N_CTX * H) + off_h * N_CTX qo_offset_y = offset_y + start_m * BLOCK_M - lo, hi = 0, N_CTX + lo, hi = _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) kv_offset_y = offset_y + lo return start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y @@ -127,7 +151,99 @@ def _join_n(xs): return x -@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"]) +@triton.jit +def _mask_scalar(qk, col_limit_right, s, i): + col_lim_right_s = col_limit_right - s + col_lim_right_cur = max(col_lim_right_s, 0) + mask = -1 << col_lim_right_cur + mask_i_bit = (mask & (1 << i)) == 0 + return tl.where(mask_i_bit, qk, -float("inf")) + + +@triton.jit +def _apply_causal_mask(qk, col_limit_right, HEAD_DIM: tl.constexpr): + # Apply causal mask via a bitmask calculated for each block of 16 elements. + # This allows the efficient R2P (register to predicate) instruction to be used at the SASS level. + # Credit to Tri Dao, + # https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78 + # + # NOTE: We use map_elementiwse here in order to generate an interleaved sequence of instructions + # that processes one element of qk at a time. This improves ptxas's resulting SASS. + offs_n = tl.arange(0, HEAD_DIM)[None, :] + s = offs_n & ~0xF + i = offs_n & 0xF + return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i) + + +@triton.jit +def _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_MMA_SLICES: tl.constexpr, + NUM_MMA_GROUPS: tl.constexpr, + STAGE: tl.constexpr, +): + lo, hi = _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) + + for start_n in tl.range(lo, hi, BLOCK_N): + _, qk_phase = _get_bufidx_phase(accum_cnt_qk, 1) + tlx.barrier_wait(tlx.local_view(qk_fulls, cid), qk_phase) + qk = tlx.local_load(tlx.local_view(qk_tiles, cid)) + + if STAGE == 2: + col_limit_right = (offs_m - start_n + 1)[:, None] + qk = _apply_causal_mask(qk, col_limit_right, HEAD_DIM) + + # compute m_i, p in registers + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + + # -- compute correction factor + alpha = tl.math.exp2(m_i - m_ij) + tlx.barrier_wait(tlx.local_view(alpha_empties, cid), qk_phase ^ 1) + # Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1 + tlx.local_store(tlx.local_view(alpha_tiles, cid * HEAD_DIM), alpha[:, None]) + tlx.barrier_arrive(tlx.local_view(alpha_fulls, cid)) + + qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) + qks = _split_n(qk, NUM_MMA_SLICES) + ps = () + for slice_id in tl.static_range(0, NUM_MMA_SLICES): + # prepare p for the v dot + # Use p[NUM_MMA_SLICES + slice_id] for cid=0, and + # p[NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id] for cid=1 + p_bufIdx = cid * NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id + p_i = tl.math.exp2(qks[slice_id]) + tlx.local_store(tlx.local_view(p_tiles, p_bufIdx), p_i.to(out_dtype)) + tlx.barrier_arrive(tlx.local_view(p_fulls, slice_id + cid * NUM_MMA_SLICES)) + ps = ps + (p_i, ) + + p = _join_n(ps) + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + m_i = m_ij + accum_cnt_qk += 1 + + return m_i, l_i, accum_cnt_qk + + +@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "STAGE"]) @triton.jit def _attn_fwd_ws(sm_scale, M, # Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # @@ -135,6 +251,7 @@ def _attn_fwd_ws(sm_scale, M, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # FP8_OUTPUT: tl.constexpr, # + STAGE: tl.constexpr, # NUM_BUFFERS_Q: tl.constexpr, # NUM_BUFFERS_KV: tl.constexpr, # NUM_BUFFERS_QK: tl.constexpr, # @@ -192,12 +309,27 @@ def _attn_fwd_ws(sm_scale, M, # tlx.storage_kind.tmem, reuse=qk_tiles, ) - alpha_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, BLOCK_N * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - l_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, BLOCK_N * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - m_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, BLOCK_N * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) + alpha_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + BLOCK_N * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + l_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + BLOCK_N * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + m_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + BLOCK_N * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) acc_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tl.float32, NUM_MMA_GROUPS, tlx.storage_kind.tmem) @@ -219,7 +351,7 @@ def _attn_fwd_ws(sm_scale, M, # for i in range(0, tiles_per_sm): # initialize offsets start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets( - tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) for _ in tl.range(lo, hi, BLOCK_N): _, phase = _get_bufidx_phase(accum_cnt, 1) for cid in tl.static_range(0, NUM_MMA_GROUPS): @@ -251,7 +383,7 @@ def _attn_fwd_ws(sm_scale, M, # tlx.barrier_arrive(qk_empties[cid]) m = tlx.local_load(m_tiles[cid * HEAD_DIM + 2]) m += tl.math.log2(l) - offs_m = start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT) + offs_m = (start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)) m_ptrs = M + off_hz * N_CTX + offs_m tl.store(m_ptrs, tl.reshape(m, [BLOCK_M_SPLIT])) @@ -283,48 +415,68 @@ def _attn_fwd_ws(sm_scale, M, # for i in range(0, tiles_per_sm): # initialize offsets start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets( - tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M_SPLIT, HEAD_DIM], dtype=tl.float32) qk_scale = sm_scale qk_scale *= 1.44269504 # 1/log(2) + out_dtype = tlx.dtype_of(desc_v) cid = tlx.async_task_replica_id() - for _ in tl.range(lo, hi, BLOCK_N): - _, qk_phase = _get_bufidx_phase(accum_cnt_qk, 1) - tlx.barrier_wait(qk_fulls[cid], qk_phase) - qk = tlx.local_load(qk_tiles[cid]) - - # compute m_i, p in registers - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - - # -- compute correction factor - alpha = tl.math.exp2(m_i - m_ij) - tlx.barrier_wait(alpha_empties[cid], qk_phase ^ 1) - # Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1 - tlx.local_store(alpha_tiles[cid * HEAD_DIM], alpha[:, None]) - tlx.barrier_arrive(alpha_fulls[cid]) - - qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) - qks = _split_n(qk, NUM_MMA_SLICES) - ps = () - for slice_id in tl.static_range(0, NUM_MMA_SLICES): - # prepare p for the v dot - # Use p[NUM_MMA_SLICES + slice_id] for cid=0, and - # p[NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id] for cid=1 - p_bufIdx = cid * NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id - p_i = tl.math.exp2(qks[slice_id]) - tlx.local_store(p_tiles[p_bufIdx], p_i.to(tlx.dtype_of(desc_v))) - tlx.barrier_arrive(p_fulls[slice_id + cid * NUM_MMA_SLICES]) - ps = ps + (p_i, ) - - p = _join_n(ps) - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - m_i = m_ij - accum_cnt_qk += 1 + offs_m = (start_m * BLOCK_M) + ((cid * BLOCK_M_SPLIT) + tl.arange(0, BLOCK_M_SPLIT)) + if STAGE & 1: + m_i, l_i, accum_cnt_qk = _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M, + BLOCK_N, + HEAD_DIM, + NUM_MMA_SLICES, + NUM_MMA_GROUPS, + STAGE=4 - STAGE, + ) + + if STAGE & 2: + m_i, l_i, accum_cnt_qk = _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M, + BLOCK_N, + HEAD_DIM, + NUM_MMA_SLICES, + NUM_MMA_GROUPS, + STAGE=2, + ) # prepare l_i for the epilog # Use l[1]/l[1+HEAD_DIM] and m[2][2 + HEAD_DIM] @@ -341,7 +493,7 @@ def _attn_fwd_ws(sm_scale, M, # for j in range(0, tiles_per_sm): # initialize offsets - _, _, lo, hi, _, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + _, _, lo, hi, _, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) q_bufIdx, q_phase = _get_bufidx_phase(j, NUM_BUFFERS_Q) k_bufIdx, k_phase = _get_bufidx_phase(accum_cnt_kv, NUM_BUFFERS_KV) @@ -432,9 +584,9 @@ def _attn_fwd_ws(sm_scale, M, # [BLOCK_N * slice_id // NUM_MMA_SLICES, 0], [BLOCK_N // NUM_MMA_SLICES, HEAD_DIM], ) - p_bufIdx = 1 * NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id + p_bufIdx = (1 * NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id) use_acc = acc1_init if slice_id == 0 else True - mBarriers = [kv_empties[v_bufIdx_prev]] if slice_id == NUM_MMA_SLICES - 1 else [] + mBarriers = ([kv_empties[v_bufIdx_prev]] if slice_id == NUM_MMA_SLICES - 1 else []) tlx.async_dot( p_tiles[p_bufIdx], kv_slice, @@ -489,9 +641,9 @@ def _attn_fwd_ws(sm_scale, M, # [BLOCK_N * slice_id // NUM_MMA_SLICES, 0], [BLOCK_N // NUM_MMA_SLICES, HEAD_DIM], ) - p_bufIdx = 1 * NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id + p_bufIdx = (1 * NUM_MMA_GROUPS * NUM_MMA_SLICES + NUM_MMA_SLICES + slice_id) use_acc = acc1_init if slice_id == 0 else True - mBarriers = [acc_empties[1], kv_empties[v_bufIdx]] if slice_id == NUM_MMA_SLICES - 1 else [] + mBarriers = ([acc_empties[1], kv_empties[v_bufIdx]] if slice_id == NUM_MMA_SLICES - 1 else []) tlx.async_dot( p_tiles[p_bufIdx], kv_slice, @@ -509,7 +661,8 @@ def _attn_fwd_ws(sm_scale, M, # accum_cnt_kv = 0 for i in range(0, tiles_per_sm): # initialize offsets - _, _, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + _, _, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, + STAGE) # load q0 q_bufIdx, q_phase = _get_bufidx_phase(i, NUM_BUFFERS_Q) @@ -581,7 +734,7 @@ def _attn_fwd_ws(sm_scale, M, # # initialize offsets for i in range(0, tiles_per_sm): # initialize offsets - _, _, _, _, qo_offset_y, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M) + _, _, _, _, qo_offset_y, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE) _, phase = _get_bufidx_phase(i, 1) for cid in tl.static_range(0, NUM_MMA_GROUPS): tlx.barrier_wait(o_fulls[cid], phase) @@ -904,8 +1057,13 @@ def _attn_bwd_ws( tlx.storage_kind.tmem, ) - dq_tiles = tlx.local_alloc((BLOCK_M1, HEAD_DIM), tl.float32, NUM_BUFFERS_TMEM, tlx.storage_kind.tmem, - reuse=dp_tiles) + dq_tiles = tlx.local_alloc( + (BLOCK_M1, HEAD_DIM), + tl.float32, + NUM_BUFFERS_TMEM, + tlx.storage_kind.tmem, + reuse=dp_tiles, + ) dv_tiles = tlx.local_alloc((BLOCK_N1, HEAD_DIM), tl.float32, NUM_BUFFERS_KV, tlx.storage_kind.tmem) dk_tiles = tlx.local_alloc((BLOCK_N1, HEAD_DIM), tl.float32, NUM_BUFFERS_KV, tlx.storage_kind.tmem) @@ -1199,8 +1357,12 @@ def _attn_bwd_ws( # Load K kv_buf_id, _ = _get_bufidx_phase(0, NUM_BUFFERS_KV) tlx.barrier_expect_bytes(k_fulls[kv_buf_id], 2 * BLOCK_N1 * HEAD_DIM) # float16 - tlx.async_descriptor_load(desc_k, k_tiles[kv_buf_id], [(off_bh + start_n).to(tl.int32), 0], - k_fulls[kv_buf_id]) + tlx.async_descriptor_load( + desc_k, + k_tiles[kv_buf_id], + [(off_bh + start_n).to(tl.int32), 0], + k_fulls[kv_buf_id], + ) # Load Q curr_m = start_m @@ -1209,19 +1371,32 @@ def _attn_bwd_ws( q_buf_id, q_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_Q) tlx.barrier_wait(q_empties[q_buf_id], q_phase ^ 1) tlx.barrier_expect_bytes(q_fulls[q_buf_id], 2 * BLOCK_M1 * HEAD_DIM) - tlx.async_descriptor_load(desc_q, q_tiles[q_buf_id], [(off_bh + curr_m).to(tl.int32), 0], q_fulls[q_buf_id]) + tlx.async_descriptor_load( + desc_q, + q_tiles[q_buf_id], + [(off_bh + curr_m).to(tl.int32), 0], + q_fulls[q_buf_id], + ) # Load V tlx.barrier_expect_bytes(v_fulls[kv_buf_id], 2 * BLOCK_N1 * HEAD_DIM) # float16 - tlx.async_descriptor_load(desc_v, v_tiles[kv_buf_id], [(off_bh + start_n).to(tl.int32), 0], - v_fulls[kv_buf_id]) + tlx.async_descriptor_load( + desc_v, + v_tiles[kv_buf_id], + [(off_bh + start_n).to(tl.int32), 0], + v_fulls[kv_buf_id], + ) # Load dO do_buf_id, do_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DO) tlx.barrier_wait(do_empties[do_buf_id], do_phase ^ 1) tlx.barrier_expect_bytes(do_fulls[do_buf_id], 2 * BLOCK_M1 * HEAD_DIM) - tlx.async_descriptor_load(desc_do, do_tiles[do_buf_id], [(off_bh + curr_m).to(tl.int32), 0], - do_fulls[do_buf_id]) + tlx.async_descriptor_load( + desc_do, + do_tiles[do_buf_id], + [(off_bh + curr_m).to(tl.int32), 0], + do_fulls[do_buf_id], + ) curr_m += step_m for blk_idx in range(1, num_steps): @@ -1230,27 +1405,38 @@ def _attn_bwd_ws( # Load Q tlx.barrier_wait(q_empties[q_buf_id], q_phase ^ 1) tlx.barrier_expect_bytes(q_fulls[q_buf_id], 2 * BLOCK_M1 * HEAD_DIM) - tlx.async_descriptor_load(desc_q, q_tiles[q_buf_id], [(off_bh + curr_m).to(tl.int32), 0], - q_fulls[q_buf_id]) + tlx.async_descriptor_load( + desc_q, + q_tiles[q_buf_id], + [(off_bh + curr_m).to(tl.int32), 0], + q_fulls[q_buf_id], + ) # Load dO tlx.barrier_wait(do_empties[do_buf_id], do_phase ^ 1) tlx.barrier_expect_bytes(do_fulls[do_buf_id], 2 * BLOCK_M1 * HEAD_DIM) - tlx.async_descriptor_load(desc_do, do_tiles[do_buf_id], [(off_bh + curr_m).to(tl.int32), 0], - do_fulls[do_buf_id]) + tlx.async_descriptor_load( + desc_do, + do_tiles[do_buf_id], + [(off_bh + curr_m).to(tl.int32), 0], + do_fulls[do_buf_id], + ) curr_m += step_m class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, sm_scale): + def forward(ctx, q, k, v, sm_scale, causal): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. HEAD_DIM_V = v.shape[-1] assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + stage = 3 if causal else 1 + o = torch.empty_like(q) extra_kern_args = {} @@ -1259,13 +1445,38 @@ def forward(ctx, q, k, v, sm_scale): y_dim = q.shape[0] * q.shape[1] * q.shape[2] dummy_block = [1, 1] - desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_q = TensorDescriptor( + q, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) if q.dtype == torch.float8_e5m2: - desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], block_shape=dummy_block) + desc_v = TensorDescriptor( + v, + shape=[HEAD_DIM_K, y_dim], + strides=[q.shape[2], 1], + block_shape=dummy_block, + ) else: - desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_v = TensorDescriptor( + v, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) + desc_o = TensorDescriptor( + o, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) def alloc_fn(size: int, align: int, _): return torch.empty(size, dtype=torch.int8, device="cuda") @@ -1286,13 +1497,20 @@ def grid(META): ctx.grid = grid _attn_fwd_ws[grid]( - sm_scale, M, # - q.shape[0], q.shape[1], # - desc_q, desc_k, desc_v, desc_o, # + sm_scale, + M, # + q.shape[0], + q.shape[1], # + desc_q, + desc_k, + desc_v, + desc_o, # N_CTX=q.shape[2], # HEAD_DIM=HEAD_DIM_K, # FP8_OUTPUT=q.dtype == torch.float8_e5m2, # - **extra_kern_args) + STAGE=stage, # + **extra_kern_args, + ) ctx.save_for_backward(q, k, v, o, M) ctx.sm_scale = sm_scale @@ -1326,20 +1544,48 @@ def backward(ctx, do): dummy_block = [1, 1] HEAD_DIM = ctx.HEAD_DIM - desc_k = TensorDescriptor(arg_k, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) - desc_v = TensorDescriptor(v, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) - desc_q = TensorDescriptor(q, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) - desc_do = TensorDescriptor(do, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) - desc_dq = TensorDescriptor(dq, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) - desc_dk = TensorDescriptor(dk, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) - desc_dv = TensorDescriptor(dv, shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=dummy_block) + desc_k = TensorDescriptor( + arg_k, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) + desc_v = TensorDescriptor( + v, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) + desc_q = TensorDescriptor( + q, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) + desc_do = TensorDescriptor( + do, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) + desc_dq = TensorDescriptor( + dq, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) + desc_dk = TensorDescriptor( + dk, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) + desc_dv = TensorDescriptor( + dv, + shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM], + strides=[HEAD_DIM, 1], + block_shape=dummy_block, + ) def alloc_fn(size: int, align: int, _): return torch.empty(size, dtype=torch.int8, device="cuda") @@ -1347,9 +1593,11 @@ def alloc_fn(size: int, align: int, _): triton.set_allocator(alloc_fn) def grid(meta): - return (triton.cdiv(N_CTX, meta['BLOCK_N1']), # tiles along N (K/V) - 1, # (or cdiv over M if you need) - BATCH * N_HEAD) # batch*heads + return ( + triton.cdiv(N_CTX, meta["BLOCK_N1"]), # tiles along N (K/V) + 1, # (or cdiv over M if you need) + BATCH * N_HEAD, + ) # batch*heads _attn_bwd_ws[grid]( desc_q, desc_k, desc_v, ctx.sm_scale, desc_do, desc_dq, desc_dk, desc_dv, # @@ -1374,9 +1622,10 @@ def grid(meta): @pytest.mark.parametrize("H", [16]) @pytest.mark.parametrize("N_CTX", [1024]) @pytest.mark.parametrize("HEAD_DIM", [128]) -@pytest.mark.parametrize("mode", ["bwd"]) +@pytest.mark.parametrize("mode", ["fwd", "bwd"]) @pytest.mark.parametrize("provider", ["triton-fp16"]) -def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): +@pytest.mark.parametrize("causal", [True, False]) +def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16): torch.manual_seed(20) q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) @@ -1384,16 +1633,16 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): sm_scale = 0.5 # reference implementation ref_dtype = dtype + if mode == "fwd" and not causal: + pytest.skip("Only test fwd with causal") + elif mode == "bwd" and causal: + pytest.skip("Causal not supported for bwd yet") if mode == "fwd" and "fp8" in provider: ref_dtype = torch.float32 q = q.to(ref_dtype) k = k.to(ref_dtype) v = v.to(ref_dtype) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - p = torch.softmax(p.float(), dim=-1) - p = p.to(ref_dtype) - # p = torch.exp(p) - ref_out = torch.matmul(p, v).half() + ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal) if mode == "bwd": dout = torch.randn_like(q) ref_out.backward(dout) @@ -1407,7 +1656,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): v = v.permute(0, 1, 3, 2).contiguous() v = v.permute(0, 1, 3, 2) v = v.to(torch.float8_e5m2) - tri_out = attention(q, k, v, sm_scale).half() + tri_out = attention(q, k, v, sm_scale, causal).half() if mode == "fwd": atol = 3 if "fp8" in provider else 1e-2 torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) @@ -1422,7 +1671,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): rtol = 0.0 # Relative tolerance workaround for known hardware limitation of CDNA2 GPU. # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + if (torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a"): rtol = 1e-2 torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol) torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol) @@ -1430,8 +1679,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): try: - from flash_attn.flash_attn_interface import \ - flash_attn_qkvpacked_func as flash_attn_func + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as flash_attn_func, ) + HAS_FLASH = True except BaseException: HAS_FLASH = False @@ -1484,7 +1734,12 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, mode, provider, device=DEVI ms = triton.testing.do_bench(fn) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn( + (BATCH, N_CTX, 3, H, HEAD_DIM), + dtype=dtype, + device=device, + requires_grad=True, + ) fn = lambda: flash_attn_func(qkv) if mode == "bwd": o = fn() diff --git a/third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py b/third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py index bbc914afc0..19546518ab 100644 --- a/third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py +++ b/third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py @@ -4,8 +4,8 @@ import triton import triton.language as tl import triton.language.extra.tlx as tlx -from triton.tools.tensor_descriptor import TensorDescriptor from triton._internal_testing import is_blackwell +from triton.tools.tensor_descriptor import TensorDescriptor DEVICE = triton.runtime.driver.active.get_active_torch_device() @@ -31,8 +31,18 @@ def _host_descriptor_pre_hook(nargs): configs = [ # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'NUM_BUFFERS_KV': 3, 'NUM_BUFFERS_QK': 1, 'NUM_MMA_GROUPS': 1}, # num_stages=0, num_warps=4, pre_hook=_host_descriptor_pre_hook), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'NUM_BUFFERS_KV': 3, 'NUM_BUFFERS_QK': 1, 'NUM_MMA_GROUPS': 2}, - num_stages=0, num_warps=4, pre_hook=_host_descriptor_pre_hook), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "NUM_BUFFERS_KV": 3, + "NUM_BUFFERS_QK": 1, + "NUM_MMA_GROUPS": 2, + }, + num_stages=0, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), ] @@ -44,19 +54,134 @@ def _get_bufidx_phase(accum_cnt, NUM_BUFFERS_KV): @triton.jit -def _compute_offsets(H, N_CTX, BLOCK_M): +def _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr): + if STAGE == 1: + # First part of STAGE == 3 in _get_fused_loop_bounds + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Second part of STAGE == 3 in _get_fused_loop_bounds + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + else: + tl.static_assert(STAGE == 3) + # Maps to STAGE=1 in _get_fused_loop_bounds + lo, hi = 0, N_CTX + return lo, hi + + +@triton.jit +def _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr): + if STAGE == 1: + return 0, N_CTX + else: + tl.static_assert(STAGE == 3) + return 0, (start_m + 1) * BLOCK_M + + +@triton.jit +def _compute_offsets(H, N_CTX, BLOCK_M, STAGE: tl.constexpr): start_m = tl.program_id(0) off_hz = tl.program_id(1) off_z = off_hz // H off_h = off_hz % H offset_y = off_z * (N_CTX * H) + off_h * N_CTX qo_offset_y = offset_y + start_m * BLOCK_M - lo, hi = 0, N_CTX + lo, hi = _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) kv_offset_y = offset_y + lo return start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y -@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"]) +@triton.jit +def _mask_scalar(qk, col_limit_right, s, i): + col_lim_right_s = col_limit_right - s + col_lim_right_cur = max(col_lim_right_s, 0) + mask = -1 << col_lim_right_cur + mask_i_bit = (mask & (1 << i)) == 0 + return tl.where(mask_i_bit, qk, -float("inf")) + + +@triton.jit +def _apply_causal_mask(qk, col_limit_right, HEAD_DIM: tl.constexpr): + # Apply causal mask via a bitmask calculated for each block of 16 elements. + # This allows the efficient R2P (register to predicate) instruction to be used at the SASS level. + # Credit to Tri Dao, + # https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78 + # + # NOTE: We use map_elementiwse here in order to generate an interleaved sequence of instructions + # that processes one element of qk at a time. This improves ptxas's resulting SASS. + offs_n = tl.arange(0, HEAD_DIM)[None, :] + s = offs_n & ~0xF + i = offs_n & 0xF + return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i) + + +@triton.jit +def _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_BUFFERS_QK: tl.constexpr, + NUM_MMA_GROUPS: tl.constexpr, + STAGE: tl.constexpr, +): + lo, hi = _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE) + + for start_n in tl.range(lo, hi, BLOCK_N): + + qk_bufIdx, qk_phase = _get_bufidx_phase(accum_cnt_qk, NUM_BUFFERS_QK) + qk_bufIdx += cid * NUM_BUFFERS_QK + + tlx.barrier_wait(tlx.local_view(qk_fulls, qk_bufIdx), qk_phase) + qk = tlx.local_load(tlx.local_view(qk_tiles, qk_bufIdx)) + + if STAGE == 2: + col_limit_right = (offs_m - start_n + 1)[:, None] + qk = _apply_causal_mask(qk, col_limit_right, HEAD_DIM) + + # compute m_i, p in registers + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + + # -- compute correction factor + alpha = tl.math.exp2(m_i - m_ij) + tlx.barrier_wait(tlx.local_view(alpha_empties, qk_bufIdx), qk_phase ^ 1) + # Use alpha[0] for cid=0, and alpha[HEAD_DIM * NUM_BUFFERS_QK] for cid=1 + tlx.local_store(tlx.local_view(alpha_tiles, cid * HEAD_DIM * NUM_BUFFERS_QK), alpha[:, None]) + tlx.barrier_arrive(tlx.local_view(alpha_fulls, qk_bufIdx)) + + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + p = p.to(out_dtype) + + # prepare p for the v dot + # Use p[1] for cid=0, and p[3] for cid=1 + p_bufIdx = 1 + cid * NUM_MMA_GROUPS * NUM_BUFFERS_QK + tlx.local_store(tlx.local_view(p_tiles, p_bufIdx), p) + tlx.barrier_arrive(tlx.local_view(p_fulls, qk_bufIdx)) + + l_i = l_i * alpha + l_ij + m_i = m_ij + accum_cnt_qk += 1 + + return m_i, l_i, accum_cnt_qk + + +@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "STAGE"]) @triton.jit def _attn_fwd_ws(sm_scale, M, # Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # @@ -64,6 +189,7 @@ def _attn_fwd_ws(sm_scale, M, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # FP8_OUTPUT: tl.constexpr, # + STAGE: tl.constexpr, # NUM_BUFFERS_KV: tl.constexpr, # NUM_BUFFERS_QK: tl.constexpr, # NUM_MMA_GROUPS: tl.constexpr, # @@ -83,21 +209,49 @@ def _attn_fwd_ws(sm_scale, M, # kv_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV) # allocate TMEM buffers and barriers - qk_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem) + qk_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, HEAD_DIM), + tl.float32, + NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + ) # Shared buffer for QK, P and Alpha, l, and m. # Alpha/l/m lives in the lower half of qk_buf, and P lives in the upper half. - p_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tlx.dtype_of(desc_v), NUM_MMA_GROUPS * NUM_BUFFERS_QK * 2, - tlx.storage_kind.tmem, reuse=qk_tiles) - alpha_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - l_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - m_tiles = tlx.local_alloc((BLOCK_M_SPLIT, 1), tl.float32, HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem, reuse=qk_tiles) - - acc_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK, - tlx.storage_kind.tmem) + p_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, HEAD_DIM), + tlx.dtype_of(desc_v), + NUM_MMA_GROUPS * NUM_BUFFERS_QK * 2, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + alpha_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + l_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + m_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, 1), + tl.float32, + HEAD_DIM * NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + reuse=qk_tiles, + ) + + acc_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, HEAD_DIM), + tl.float32, + NUM_MMA_GROUPS * NUM_BUFFERS_QK, + tlx.storage_kind.tmem, + ) qk_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS * NUM_BUFFERS_QK) p_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS * NUM_BUFFERS_QK) @@ -112,7 +266,7 @@ def _attn_fwd_ws(sm_scale, M, # # correction group with tlx.async_task("default"): # initialize offsets - start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(H, N_CTX, BLOCK_M) + start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(H, N_CTX, BLOCK_M, STAGE) accum_cnt = 0 buf_idx = 0 phase = 0 @@ -142,7 +296,7 @@ def _attn_fwd_ws(sm_scale, M, # l = tlx.local_load(l_tiles[cid * HEAD_DIM * NUM_BUFFERS_QK + 1]) m = tlx.local_load(m_tiles[cid * HEAD_DIM * NUM_BUFFERS_QK + 2]) m += tl.math.log2(l) - offs_m = start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT) + offs_m = (start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)) m_ptrs = M + off_hz * N_CTX + offs_m tl.store(m_ptrs, tl.reshape(m, [BLOCK_M_SPLIT])) @@ -155,7 +309,7 @@ def _attn_fwd_ws(sm_scale, M, # # softmax groups with tlx.async_task(num_warps=4, registers=152, replicate=NUM_MMA_GROUPS): # initialize offsets - start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(H, N_CTX, BLOCK_M) + start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(H, N_CTX, BLOCK_M, STAGE) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0 @@ -164,38 +318,61 @@ def _attn_fwd_ws(sm_scale, M, # qk_scale *= 1.44269504 # 1/log(2) accum_cnt_qk = 0 + out_dtype = tlx.dtype_of(desc_v) + cid = tlx.async_task_replica_id() - for _ in tl.range(lo, hi, BLOCK_N): - qk_bufIdx, qk_phase = _get_bufidx_phase(accum_cnt_qk, NUM_BUFFERS_QK) - qk_bufIdx += cid * NUM_BUFFERS_QK - - tlx.barrier_wait(qk_fulls[qk_bufIdx], qk_phase) - qk = tlx.local_load(qk_tiles[qk_bufIdx]) - - # compute m_i, p in registers - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - - # -- compute correction factor - alpha = tl.math.exp2(m_i - m_ij) - tlx.barrier_wait(alpha_empties[qk_bufIdx], qk_phase ^ 1) - # Use alpha[0] for cid=0, and alpha[HEAD_DIM * NUM_BUFFERS_QK] for cid=1 - tlx.local_store(alpha_tiles[cid * HEAD_DIM * NUM_BUFFERS_QK], alpha[:, None]) - tlx.barrier_arrive(alpha_fulls[qk_bufIdx]) - - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - p = p.to(tlx.dtype_of(desc_v)) - - # prepare p for the v dot - # Use p[1] for cid=0, and p[3] for cid=1 - p_bufIdx = 1 + cid * NUM_MMA_GROUPS * NUM_BUFFERS_QK - tlx.local_store(p_tiles[p_bufIdx], p) - tlx.barrier_arrive(p_fulls[qk_bufIdx]) - - l_i = l_i * alpha + l_ij - m_i = m_ij - accum_cnt_qk += 1 + offs_m = (start_m * BLOCK_M) + ((cid * BLOCK_M_SPLIT) + tl.arange(0, BLOCK_M_SPLIT)) + if STAGE & 1: + m_i, l_i, accum_cnt_qk = _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M, + BLOCK_N, + HEAD_DIM, + NUM_BUFFERS_QK, + NUM_MMA_GROUPS, + STAGE=4 - STAGE, + ) + + if STAGE & 2: + m_i, l_i, accum_cnt_qk = _softmax_inner_loop( + qk_fulls, + qk_tiles, + p_fulls, + p_tiles, + alpha_empties, + alpha_fulls, + alpha_tiles, + cid, + accum_cnt_qk, + qk_scale, + offs_m, + m_i, + l_i, + start_m, + N_CTX, + out_dtype, + BLOCK_M, + BLOCK_N, + HEAD_DIM, + NUM_BUFFERS_QK, + NUM_MMA_GROUPS, + STAGE=2, + ) # prepare l_i for the epilog # Use l[1]/l[1+HEAD_DIM * NUM_BUFFERS_QK] and m[2][2 + HEAD_DIM * NUM_BUFFERS_QK] @@ -206,7 +383,7 @@ def _attn_fwd_ws(sm_scale, M, # # mma group with tlx.async_task(num_warps=1, registers=24): - _, _, lo, hi, _, _ = _compute_offsets(H, N_CTX, BLOCK_M) + _, _, lo, hi, _, _ = _compute_offsets(H, N_CTX, BLOCK_M, STAGE) # loop over k, v and update accumulator accum_cnt_kv = 0 @@ -324,7 +501,7 @@ def _attn_fwd_ws(sm_scale, M, # # load with tlx.async_task(num_warps=1, registers=24): # initialize offsets - start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(H, N_CTX, BLOCK_M) + start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(H, N_CTX, BLOCK_M, STAGE) # load q0 tlx.barrier_expect_bytes(q_fulls[0], 2 * BLOCK_M_SPLIT * HEAD_DIM) # float16 @@ -390,13 +567,16 @@ def _attn_fwd_ws(sm_scale, M, # class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, sm_scale): + def forward(ctx, q, k, v, sm_scale, causal): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. HEAD_DIM_V = v.shape[-1] assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + stage = 3 if causal else 1 + o = torch.empty_like(q) extra_kern_args = {} @@ -405,13 +585,38 @@ def forward(ctx, q, k, v, sm_scale): y_dim = q.shape[0] * q.shape[1] * q.shape[2] dummy_block = [1, 1] - desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_q = TensorDescriptor( + q, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) if q.dtype == torch.float8_e5m2: - desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], block_shape=dummy_block) + desc_v = TensorDescriptor( + v, + shape=[HEAD_DIM_K, y_dim], + strides=[q.shape[2], 1], + block_shape=dummy_block, + ) else: - desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_v = TensorDescriptor( + v, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) + desc_o = TensorDescriptor( + o, + shape=[y_dim, HEAD_DIM_K], + strides=[HEAD_DIM_K, 1], + block_shape=dummy_block, + ) def alloc_fn(size: int, align: int, _): return torch.empty(size, dtype=torch.int8, device="cuda") @@ -419,17 +624,28 @@ def alloc_fn(size: int, align: int, _): triton.set_allocator(alloc_fn) def grid(META): - return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + return ( + triton.cdiv(q.shape[2], META["BLOCK_M"]), + q.shape[0] * q.shape[1], + 1, + ) ctx.grid = grid _attn_fwd_ws[grid]( - sm_scale, M, # - q.shape[0], q.shape[1], # - desc_q, desc_k, desc_v, desc_o, # + sm_scale, + M, # + q.shape[0], + q.shape[1], # + desc_q, + desc_k, + desc_v, + desc_o, # N_CTX=q.shape[2], # HEAD_DIM=HEAD_DIM_K, # FP8_OUTPUT=q.dtype == torch.float8_e5m2, # - **extra_kern_args) + STAGE=stage, # + **extra_kern_args, + ) ctx.save_for_backward(q, k, v, o, M) ctx.sm_scale = sm_scale @@ -450,7 +666,8 @@ def grid(META): @pytest.mark.parametrize("HEAD_DIM", [128]) @pytest.mark.parametrize("mode", ["fwd"]) @pytest.mark.parametrize("provider", ["triton-fp16"]) -def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): +@pytest.mark.parametrize("causal", [True, False]) +def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16): if mode == "bwd": pytest.skip("Backward pass not supported.") torch.manual_seed(20) @@ -465,11 +682,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): q = q.to(ref_dtype) k = k.to(ref_dtype) v = v.to(ref_dtype) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - p = torch.softmax(p.float(), dim=-1) - p = p.to(ref_dtype) - # p = torch.exp(p) - ref_out = torch.matmul(p, v).half() + ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal) # triton implementation if mode == "fwd" and "fp8" in provider: q = q.to(torch.float8_e5m2) @@ -477,7 +690,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): v = v.permute(0, 1, 3, 2).contiguous() v = v.permute(0, 1, 3, 2) v = v.to(torch.float8_e5m2) - tri_out = attention(q, k, v, sm_scale).half() + tri_out = attention(q, k, v, sm_scale, causal).half() if mode == "fwd": atol = 3 if "fp8" in provider else 1e-2 torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) @@ -485,8 +698,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, dtype=torch.float16): try: - from flash_attn.flash_attn_interface import \ - flash_attn_qkvpacked_func as flash_attn_func + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as flash_attn_func, ) + HAS_FLASH = True except BaseException: HAS_FLASH = False @@ -537,7 +751,12 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, mode, provider, device=DEVI ms = triton.testing.do_bench(fn) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn( + (BATCH, N_CTX, 3, H, HEAD_DIM), + dtype=dtype, + device=device, + requires_grad=True, + ) fn = lambda: flash_attn_func(qkv) if mode == "bwd": o = fn()