diff --git a/aiter/ops/triton/_triton_kernels/unified_attention.py b/aiter/ops/triton/_triton_kernels/unified_attention.py
index e22fce3932..811d3405ef 100644
--- a/aiter/ops/triton/_triton_kernels/unified_attention.py
+++ b/aiter/ops/triton/_triton_kernels/unified_attention.py
@@ -9,6 +9,12 @@
 float8_info = torch.finfo(e4m3_dtype)
 
 
+@triton.jit
+def fast_exp(x):
+    RCP_LN2: tl.constexpr = 1.4426950408889634
+    return tl.math.exp2(x * RCP_LN2)
+
+
 @triton.jit
 def cdiv_fn(x, y):
     return (x + y - 1) // y
@@ -17,8 +23,8 @@ def cdiv_fn(x, y):
 @triton.jit
 def apply_softcap(S, x):
     Sdiv = S / x
-    p1 = tl.exp(Sdiv)
-    p2 = tl.exp(-Sdiv)
+    p1 = tl.math.exp2(Sdiv)
+    p2 = tl.math.exp2(-Sdiv)
     return x * (p1 - p2) / (p1 + p2)
 
 
@@ -56,7 +62,7 @@ def kernel_unified_attention_2d(
     seq_lens_ptr,  # [num_seqs]
     alibi_slopes_ptr,  # [num_query_heads]
     qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
-    scale,  # float32
+    scale: tl.constexpr,  # float32
     k_scale,  # float32
     v_scale,  # float32
     out_scale,  # float32
@@ -70,6 +76,7 @@
     output_stride_1: tl.int64,  # int, should be equal to head_size
     qq_bias_stride_0: tl.int64,  # int
     BLOCK_SIZE: tl.constexpr,  # int
+    TILE_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -92,24 +99,14 @@
     USE_FP8: tl.constexpr,  # bool
     FP8_MIN: tl.constexpr = float8_info.min,
     FP8_MAX: tl.constexpr = float8_info.max,
-    ALL_DECODE: tl.constexpr = False,
+    ALL_DECODE: tl.constexpr = False,  # bool
 ):
     kv_head_idx = tl.program_id(0)
     q_block_global_idx = tl.program_id(1)
 
-    tl.assume(kv_head_idx >= 0)
-    tl.assume(q_block_global_idx >= 0)
-    tl.assume(block_table_stride > 0)
-    tl.assume(query_stride_0 > 0)
-    tl.assume(query_stride_1 > 0)
-    tl.assume(output_stride_0 > 0)
-    tl.assume(output_stride_1 > 0)
-    tl.assume(stride_k_cache_0 > 0)
-    tl.assume(stride_k_cache_1 > 0)
-    tl.assume(stride_k_cache_2 > 0)
-    tl.assume(stride_v_cache_0 > 0)
-    tl.assume(stride_v_cache_1 > 0)
-    tl.assume(stride_v_cache_2 > 0)
+    # exp -> exp2 conversion: prescale by RCP_LN2 = log2(e) so we can use exp2 later
+    RCP_LN2 = 1.4426950408889634
+    qk_scale = scale * RCP_LN2
 
     seq_idx = find_seq_idx(
         query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
@@ -129,6 +126,7 @@
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
@@ -139,7 +137,10 @@
         + offs_d[None, :]
     )
 
-    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
+    if HEAD_SIZE_PADDED != HEAD_SIZE:
+        dim_mask = offs_d < HEAD_SIZE
+    else:
+        dim_mask = tl.full((1,), 1, dtype=tl.int1)
     query_mask_0 = query_pos < cur_batch_query_len
     query_mask_1 = query_offset_1 < num_query_heads
 
@@ -147,7 +148,6 @@
         Q_cache_modifier: tl.constexpr = ".cg"
     else:
         Q_cache_modifier: tl.constexpr = ""
-
     # Q : (BLOCK_M, HEAD_SIZE_PADDED)
     Q = tl.load(
         query_ptr + query_offset,
@@ -161,11 +161,15 @@
     if not USE_SINKS:
         M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
-        M = tl.load(
-            sink_ptr + query_offset_1,
-            mask=query_mask_1,
-            other=float("-inf"),
-        ).to(dtype=tl.float32)
+        # Prescale with RCP_LN2, needed for exp2
+        M = (
+            tl.load(
+                sink_ptr + query_offset_1,
+                mask=query_mask_1,
+                other=float("-inf"),
+            ).to(dtype=tl.float32)
+            * RCP_LN2
+        )
 
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -201,43 +205,65 @@
     # actual sequence length
     max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
 
-    # calculate the number of tiles (blocks) that need to be processed to
-    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
     # this prefix can be skipped)
-    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
+
+    # ---- Sliding-window tile pruning --------------------
+    # Default: keep the previous global behavior
+    tile_start = 0
+    tile_end = num_tiles
 
     if SLIDING_WINDOW > 0:
-        num_blocks_start = (
-            max_seq_prefix_len - SLIDING_WINDOW - BLOCK_Q - 1
-        ) // BLOCK_SIZE
-        num_blocks_start = max(0, num_blocks_start)
-    else:
-        num_blocks_start = 0
-    KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else ""
-    # iterate through tiles
-    for j in range(num_blocks_start, num_blocks):
+        # Query rows covered by this Q-block
+        qpos_lo = q_block_local_idx * BLOCK_Q
+        qpos_hi = tl.minimum(
+            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
+            cur_batch_query_len - 1,
+        )
+        # For sliding window, each query position q can only attend to
+        # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs],
+        # where q_abs = context_len + q.
+        # The union of allowed key positions for this Q-block is:
+        # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
+        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
+        last_allowed_key = context_len + qpos_hi
+        # Convert to tile indices and clamp
+        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
+        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
 
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+    KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else ""
+    # iterate through tiles (now limited to the sliding window range)
+    for j in range(tile_start, tile_end):
+        seq_offset = j * TILE_SIZE + offs_t
+        # skip the tile masking when TILE_SIZE == BLOCK_SIZE, where it is not needed
+        if TILE_SIZE == BLOCK_SIZE:
+            tile_mask = tl.full((1,), 1, dtype=tl.int1)
+        else:
+            tile_mask = seq_offset < max_seq_prefix_len
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
+        physical_block_idx = tl.load(
+            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
+        ).to(tl.int64)
 
         v_offset = (
-            physical_block_idx * stride_v_cache_0
+            physical_block_idx[:, None] * stride_v_cache_0
             + kv_head_idx * stride_v_cache_2
             + offs_d[None, :] * stride_v_cache_3
-            + offs_n[:, None] * stride_v_cache_1
+            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
         )
 
         k_offset = (
-            physical_block_idx * stride_k_cache_0
+            physical_block_idx[None, :] * stride_k_cache_0
             + kv_head_idx * stride_k_cache_2
             + offs_d[:, None] * stride_k_cache_3
-            + offs_n[None, :] * stride_k_cache_1
+            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
        )
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(
             key_cache_ptr + k_offset,
-            mask=dim_mask[:, None],
+            mask=dim_mask[:, None] & tile_mask[None, :],
             other=0.0,
             cache_modifier=KV_cache_modifier,
         )
@@ -250,10 +276,10 @@
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(
             value_cache_ptr + v_offset,
-            mask=dim_mask[None, :],
+            mask=dim_mask[None, :] & tile_mask[:, None],
             other=0.0,
             cache_modifier=KV_cache_modifier,
         )
@@ -266,17 +292,15 @@
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
-        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
-
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
-
-        S += scale * tl.dot(Q, K)
+        # S : (BLOCK_M, TILE_SIZE)
+        # qk_scale = scale * RCP_LN2 (log2(e)) so that we can use exp2 later
+        S = qk_scale * tl.dot(Q, K)
 
         if USE_SOFTCAP:
-            S = apply_softcap(S, softcap)
+            # softcap consumes the RCP_LN2 prescaling internally (via exp2),
+            # so multiply by RCP_LN2 again for the later exp2
+            S = apply_softcap(S, softcap) * RCP_LN2
 
+        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
         S = tl.where(
             query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
@@ -290,7 +314,8 @@
         )
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            # prescale with RCP_LN2 for the later exp2
+            S += alibi_slope[:, None] * (seq_offset - context_len) * RCP_LN2
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
@@ -302,23 +327,25 @@
             mask=is_query_key[None, :],  # avoid OOB for context keys
             other=0.0,
         )
-        S += qq_bias
+        # prescale with RCP_LN2 for the later exp2
+        S += qq_bias * RCP_LN2
 
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
+
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE)
-        P = tl.exp(S - m_j[:, None])
+        # P : (BLOCK_M, TILE_SIZE)
+        P = tl.math.exp2(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
         l_j = tl.sum(P, axis=1)
 
         # alpha : (BLOCK_M, )
-        alpha = tl.exp(M - m_j)
+        alpha = tl.math.exp2(M - m_j)
 
         # acc : (BLOCK_M, HEAD_SIZE_PADDED)
         acc = acc * alpha[:, None]
@@ -331,6 +358,7 @@
         acc += tl.dot(P.to(V.dtype), V)
 
     # epilogue
+    # Dividing by L here lets the compiler run Newton-Raphson on L instead of on acc, which is much larger.
     one_over_L = 1.0 / L[:, None]
     acc = acc * one_over_L
     if USE_FP8:
@@ -357,7 +385,7 @@ def kernel_unified_attention_3d(
     segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
     query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
     value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
     sink_ptr,  # [num_query_heads]
     block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
@@ -375,6 +403,7 @@
     query_stride_1: tl.int64,  # int, should be equal to head_size
     qq_bias_stride_0: tl.int64,  # int
     BLOCK_SIZE: tl.constexpr,  # int
+    TILE_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -395,25 +424,15 @@
     num_seqs: tl.int32,
     BLOCK_M: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
-    ALL_DECODE: tl.constexpr,
+    ALL_DECODE: tl.constexpr = False,  # bool
 ):
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
     segm_idx = tl.program_id(2)
 
-    tl.assume(kv_head_idx >= 0)
-    tl.assume(q_block_global_idx >= 0)
-    tl.assume(segm_idx >= 0)
-
-    tl.assume(block_table_stride > 0)
-    tl.assume(query_stride_0 > 0)
-    tl.assume(query_stride_1 > 0)
-    tl.assume(stride_k_cache_0 > 0)
-    tl.assume(stride_k_cache_1 > 0)
-    tl.assume(stride_k_cache_2 > 0)
-    tl.assume(stride_v_cache_0 > 0)
-    tl.assume(stride_v_cache_1 > 0)
-    tl.assume(stride_v_cache_2 > 0)
+    # exp -> exp2 conversion: prescale by RCP_LN2 = log2(e) so we can use exp2 later
+    RCP_LN2 = 1.4426950408889634
+    qk_scale = scale * RCP_LN2
 
     seq_idx = find_seq_idx(
         query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
@@ -436,30 +455,30 @@
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
 
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
-    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
+    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
         return
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
-
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
-
     query_offset = (
         query_offset_0[:, None] * query_stride_0
        + query_offset_1[:, None] * query_stride_1
         + offs_d[None, :]
     )
 
-    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
-    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
-    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
-
-    KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else ""
+    if HEAD_SIZE_PADDED != HEAD_SIZE:
+        dim_mask = offs_d < HEAD_SIZE
+    else:
+        dim_mask = tl.full((1,), 1, dtype=tl.int1)
+    query_mask_0 = query_pos < cur_batch_query_len
+    query_mask_1 = query_offset_1 < num_query_heads
 
     # Q : (BLOCK_M, HEAD_SIZE_PADDED)
     Q = tl.load(
@@ -472,11 +491,15 @@
 
     if USE_SINKS:
         if segm_idx == 0:
-            M = tl.load(
-                sink_ptr + query_offset_1,
-                mask=query_mask_1,
-                other=float("-inf"),
-            ).to(dtype=tl.float32)
+            # Prescale with RCP_LN2, needed for exp2
+            M = (
+                tl.load(
+                    sink_ptr + query_offset_1,
+                    mask=query_mask_1,
+                    other=float("-inf"),
+                ).to(dtype=tl.float32)
+                * RCP_LN2
+            )
         else:
             M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
@@ -500,35 +523,58 @@
                 qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
             )  # shape: [BLOCK_M]
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = (
+        context_len
+        + q_block_local_idx * BLOCK_Q
+        + (BLOCK_M - 1) // num_queries_per_kv
+        + 1
+    )
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
+    # this prefix can be skipped)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
+
+    KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else ""
 
     # iterate through tiles within current segment
     for j in range(
-        segm_idx * blocks_per_segment,
-        min((segm_idx + 1) * blocks_per_segment, num_blocks),
+        segm_idx * tiles_per_segment,
+        min((segm_idx + 1) * tiles_per_segment, num_tiles),
     ):
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+        seq_offset = j * TILE_SIZE + offs_t
+        if TILE_SIZE == BLOCK_SIZE:
+            tile_mask = tl.full((1,), 1, dtype=tl.int1)
+        else:
+            tile_mask = seq_offset < max_seq_prefix_len
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
+        physical_block_idx = tl.load(
+            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
+        ).to(tl.int64)
 
         v_offset = (
-            physical_block_idx * stride_v_cache_0
+            physical_block_idx[:, None] * stride_v_cache_0
             + kv_head_idx * stride_v_cache_2
             + offs_d[None, :] * stride_v_cache_3
-            + offs_n[:, None] * stride_v_cache_1
+            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
         )
 
         k_offset = (
-            physical_block_idx * stride_k_cache_0
+            physical_block_idx[None, :] * stride_k_cache_0
             + kv_head_idx * stride_k_cache_2
             + offs_d[:, None] * stride_k_cache_3
-            + offs_n[None, :] * stride_k_cache_1
+            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
         )
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(
             key_cache_ptr + k_offset,
-            mask=dim_mask[:, None],
+            mask=dim_mask[:, None] & tile_mask[None, :],
             other=0.0,
             cache_modifier=KV_cache_modifier,
         )
@@ -541,10 +587,10 @@
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(
             value_cache_ptr + v_offset,
-            mask=dim_mask[None, :],
+            mask=dim_mask[None, :] & tile_mask[:, None],
             other=0.0,
             cache_modifier=KV_cache_modifier,
         )
@@ -557,17 +603,16 @@
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
-
-        S += scale * tl.dot(Q, K)
+        # S : (BLOCK_M, TILE_SIZE)
+        # qk_scale = scale * RCP_LN2 (log2(e)) so that we can use exp2 later
+        S = qk_scale * tl.dot(Q, K)
 
         if USE_SOFTCAP:
-            S = apply_softcap(S, softcap)
+            # softcap consumes the RCP_LN2 prescaling internally (via exp2),
+            # so multiply by RCP_LN2 again for the later exp2
+            S = apply_softcap(S, softcap) * RCP_LN2
 
         S = tl.where(
             query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
@@ -581,7 +626,8 @@
         )
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            # prescale with RCP_LN2 for the later exp2
+            S += alibi_slope[:, None] * (seq_offset - context_len) * RCP_LN2
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
@@ -593,23 +639,25 @@
             mask=is_query_key[None, :],  # avoid OOB for context keys
             other=0.0,
         )
-        S += qq_bias
+        # prescale with RCP_LN2 for the later exp2
+        S += qq_bias * RCP_LN2
 
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
+
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE,)
-        P = tl.exp(S - m_j[:, None])
+        # P : (BLOCK_M, TILE_SIZE,)
+        P = tl.math.exp2(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
         l_j = tl.sum(P, axis=1)
 
         # alpha : (BLOCK_M, )
-        alpha = tl.exp(M - m_j)
+        alpha = tl.math.exp2(M - m_j)
 
         # acc : (BLOCK_M, HEAD_SIZE_PADDED)
         acc = acc * alpha[:, None]
@@ -656,7 +704,7 @@ def reduce_segments(
     output_stride_0: tl.int64,  # int
     output_stride_1: tl.int64,  # int, should be equal to head_size
     block_table_stride: tl.int64,  # int
-    BLOCK_SIZE: tl.constexpr,  # int
+    TILE_SIZE: tl.constexpr,  # int
     HEAD_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     query_start_len_ptr,  # [num_seqs+1]
@@ -678,14 +726,18 @@
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
 
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
     # create masks for subsequent loads
-    act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
+    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
     segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
         [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32
     )
-    dim_mask = tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE
+
+    if HEAD_SIZE_PADDED != HEAD_SIZE:
+        dim_mask = tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE
+    else:
+        dim_mask = tl.full((1,), 1, dtype=tl.int1)
 
     # load segment maxima
     segm_offset = (
@@ -698,7 +750,7 @@
 
     # load and rescale segment exp sums
     segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
-    segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
+    segm_expsum = segm_expsum * tl.math.exp2(segm_max - overall_max)
     overall_expsum = tl.sum(segm_expsum)
 
     # load, rescale, and add segment attention outputs
@@ -714,7 +766,7 @@
         mask=segm_mask[:, None] & dim_mask[None, :],
         other=0.0,
     )
-    segm_output *= tl.exp(segm_max - overall_max)[:, None]
+    segm_output *= tl.math.exp2(segm_max - overall_max)[:, None]
    acc_sum = tl.sum(segm_output, axis=0)
     # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
     acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
diff --git a/aiter/ops/triton/unified_attention.py b/aiter/ops/triton/unified_attention.py
index d10c1dcd3d..b2231ee563 100644
--- a/aiter/ops/triton/unified_attention.py
+++ b/aiter/ops/triton/unified_attention.py
@@ -11,6 +11,95 @@
 )
 
 
+def select_2d_config(
+    block_size,
+    head_size,
+    sliding_window,
+    all_decode,
+    max_seqlen_q,
+    max_seqlen_k,
+    num_queries_per_kv,
+    num_2d_prgms,
+):
+    BLOCK_M = (
+        16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
+    )
+    TILE_SIZE = 64
+    # cap the pipeline depth in case head_size is large
+    max_num_stages_2d = 4
+    if head_size > 128:
+        max_num_stages_2d = 2
+    if not all_decode:
+        num_stages_2d = 1
+        num_warps = 2
+    else:
+        num_stages_2d = 3
+        num_warps = 2
+        TILE_SIZE = block_size
+
+    if max_seqlen_q >= 256:
+        BLOCK_M = 128
+        num_stages_2d = 1
+        num_warps = 4
+    BLOCK_Q = BLOCK_M // num_queries_per_kv
+    num_stages_2d = min(max_num_stages_2d, num_stages_2d)
+    return {
+        "BLOCK_M": BLOCK_M,
+        "BLOCK_Q": BLOCK_Q,
+        "TILE_SIZE": TILE_SIZE,
+        "num_warps": num_warps,
+        "num_stages": num_stages_2d,
+        "waves_per_eu": 2,
+    }
+
+
+def select_3d_config(
+    head_size, block_size, element_size, max_seqlen_k, target_num_prgms, num_2d_prgms
+):
+    reduce_num_warps = 2
+    attn_warps = 2
+    TILE_SIZE = block_size
+    MAX_SEGMENTS = min(128, math.ceil(max_seqlen_k / TILE_SIZE))
+    num_segments = math.ceil(target_num_prgms / num_2d_prgms)
+    num_segments = triton.next_power_of_2(num_segments)
+    num_segments = min(num_segments, 128)
+    MIN_SEGMENTS = 16 if TILE_SIZE <= 16 else 8
+    num_segments = max(num_segments, MIN_SEGMENTS)
+    if num_segments == MIN_SEGMENTS:
+        reduce_num_warps = 1
+    attn_config = {
+        "TILE_SIZE": TILE_SIZE,
+        "NUM_SEGMENTS_PER_SEQ": num_segments,
+        "num_warps": attn_warps,
+        "num_stages": 1,
+        "waves_per_eu": 2,
+    }
+    reduce_config = {
+        "TILE_SIZE": TILE_SIZE,
+        "NUM_SEGMENTS_PER_SEQ": num_segments,
+        "num_warps": reduce_num_warps,
+        "num_stages": 1,
+        "waves_per_eu": 2,
+    }
+    return attn_config, reduce_config
+
+
+def use_2d_kernel(
+    head_size,
+    sliding_window,
+    all_decode,
+    max_seqlen_q,
+    max_seqlen_k,
+    target_num_prgms,
+    num_2d_prgms,
+):
+    return (
+        (sliding_window > 0)
+        or (max_seqlen_k <= 512)
+        or (num_2d_prgms > target_num_prgms)
+    )
+
+
 def unified_attention(
     q,
     k,
@@ -37,17 +126,13 @@
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
 
-    block_size = v.shape[1]
-    assert (
-        q.element_size() >= 2 or block_size >= 32
-    ), "Block size must be at least 32 for fp8"
-
     if sinks is not None:
         assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
 
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None
     SLIDING_WINDOW = 1 + window_size[0]
+    block_size = v.shape[1]
 
     num_seqs = len(seqused_k)
     num_query_heads = q.shape[1]
@@ -55,11 +140,11 @@
     num_queries_per_kv = num_query_heads // num_kv_heads
     head_size = q.shape[2]
 
-    BLOCK_M = 16
+    BLOCK_M = (
+        16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
+    )
     BLOCK_Q = BLOCK_M // num_queries_per_kv
-    if BLOCK_Q == 0:
-        BLOCK_M = triton.next_power_of_2(num_queries_per_kv)
-        BLOCK_Q = BLOCK_M // num_queries_per_kv
+    assert BLOCK_Q >= 1
 
     # Ideally we would launch with kernel with:
     # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
     # However, it is slow to realize the query_lens on cpu.
@@ -69,40 +154,34 @@
     # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
     # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
     # = floor(q.shape[0] / BLOCK_Q) + num_seqs
-
+    cu_count = get_num_sms()
     total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
-    target_num_prgms = get_num_sms() * 4
+    target_num_prgms = cu_count * 4
     num_2d_prgms = total_num_q_blocks * num_kv_heads
     ALL_DECODE = max_seqlen_q == 1
+    # choose the 2d kernel over the 3d split-K kernel (see use_2d_kernel)
+    if use_2d_kernel(
+        head_size,
+        SLIDING_WINDOW,
+        ALL_DECODE,
+        max_seqlen_q,
+        max_seqlen_k,
+        target_num_prgms,
+        num_2d_prgms,
+    ):
+        config = select_2d_config(
+            block_size,
+            head_size,
+            SLIDING_WINDOW,
+            ALL_DECODE,
+            max_seqlen_q,
+            max_seqlen_k,
+            num_queries_per_kv,
+            num_2d_prgms,
+        )
+        assert config["BLOCK_Q"] >= 1
+        total_num_q_blocks = q.shape[0] // config["BLOCK_Q"] + num_seqs
 
-    # call 2d if sliding window is used
-    if SLIDING_WINDOW > 0 or num_2d_prgms >= target_num_prgms or max_seqlen_k <= 1024:
-        if ALL_DECODE == False:
-            num_stages_2d = 4
-            num_warps = 4
-        else:
-            num_stages_2d = 3
-            num_warps = 2
-        # make the block_m bigger if we already have enough parallelism
-        if num_2d_prgms >= 2 * target_num_prgms:
-            if num_2d_prgms <= 4 * target_num_prgms:
-                BLOCK_M = 64
-                num_stages_2d = 2 if SLIDING_WINDOW > 0 else 4
-            elif num_2d_prgms <= 8 * target_num_prgms:
-                BLOCK_M = 64
-                num_stages_2d = 1 if SLIDING_WINDOW > 0 else 2
-            else:
-                BLOCK_M = 64
-                num_stages_2d = 1
-        BLOCK_Q = BLOCK_M // num_queries_per_kv
-        total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
-
-        if max_seqlen_q >= 512 and block_size == 64:
-            BLOCK_M = 128
-            num_stages_2d = 1
-            num_warps = 4
-            BLOCK_Q = BLOCK_M // num_queries_per_kv
-            total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
         kernel_unified_attention_2d[
             (
                 num_kv_heads,
@@ -148,21 +227,22 @@
             stride_v_cache_2=v.stride(2),
             stride_v_cache_3=v.stride(3),
             query_start_len_ptr=cu_seqlens_q,
-            BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
-            BLOCK_M=BLOCK_M,
             USE_FP8=output_scale is not None,
             ALL_DECODE=ALL_DECODE,
-            waves_per_eu=2,
-            num_warps=num_warps,
-            num_stages=num_stages_2d,
+            **config,
         )
+
     else:
-        NUM_SEGMENTS = math.ceil(target_num_prgms / num_2d_prgms)
-        NUM_SEGMENTS = triton.next_power_of_2(NUM_SEGMENTS)
-        NUM_SEGMENTS = min(NUM_SEGMENTS, 256)
-        MIN_SEGMENTS = 16 if block_size <= 16 else 8
-        NUM_SEGMENTS = max(NUM_SEGMENTS, MIN_SEGMENTS)
+        attn_config, reduce_config = select_3d_config(
+            head_size,
+            block_size,
+            q.element_size(),
+            max_seqlen_k,
+            target_num_prgms,
+            num_2d_prgms,
+        )
+        NUM_SEGMENTS = attn_config["NUM_SEGMENTS_PER_SEQ"]
         segm_output = torch.empty(
             q.shape[0],
             num_query_heads,
@@ -215,7 +295,7 @@
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             USE_SINKS=(sinks is not None),
-            SLIDING_WINDOW=(1 + window_size[0]),
+            SLIDING_WINDOW=SLIDING_WINDOW,
             stride_k_cache_0=k.stride(0),
             stride_k_cache_1=k.stride(1),
             stride_k_cache_2=k.stride(2),
@@ -228,13 +308,9 @@
             BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
             ALL_DECODE=ALL_DECODE,
-            waves_per_eu=2,
-            num_stages=1,
-            num_warps=2,
+            **attn_config,
         )
-
         reduce_segments[(q.shape[0], num_query_heads)](
             output_ptr=out,
             segm_output_ptr=segm_output,
@@ -247,12 +323,10 @@
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
-            BLOCK_SIZE=block_size,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             query_start_len_ptr=cu_seqlens_q,
             BLOCK_Q=BLOCK_Q,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
             USE_FP8=output_scale is not None,
-            num_stages=1,
+            **reduce_config,
         )
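
Reviewer note on the exp -> exp2 rewrite: the kernels fold log2(e) into `qk_scale` once (and into the sink, ALiBi, and qq-bias terms) so that every `tl.exp` in the online softmax becomes the cheaper `tl.math.exp2`. A minimal NumPy sketch of why this is exact, relying only on the identity exp(x) = exp2(x * log2(e)); `naive_softmax` and `exp2_softmax` are illustrative names, not kernel code:

```python
# Standalone sketch (not part of the patch): checks that the base-2
# online-softmax rewrite used by the kernels matches the exp-based version.
import numpy as np

RCP_LN2 = 1.4426950408889634  # log2(e), so exp(x) == exp2(x * RCP_LN2)


def naive_softmax(scores, scale):
    s = scores * scale
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    return p / p.sum(axis=-1, keepdims=True)


def exp2_softmax(scores, scale):
    # Fold log2(e) into the scale once (the kernel's qk_scale); every
    # subsequent exp() then becomes an exp2() on the prescaled values.
    s = scores * (scale * RCP_LN2)
    m = s.max(axis=-1, keepdims=True)
    p = np.exp2(s - m)
    return p / p.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 32))
assert np.allclose(naive_softmax(scores, 0.125), exp2_softmax(scores, 0.125))
```

The same reasoning is why everything that enters the running maximum M (sinks, ALiBi, qq-bias) is multiplied by RCP_LN2 first: all values compared against or fed into exp2, including the segm_max values consumed by reduce_segments, must live in the same base-2 domain.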
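Note on the TILE_SIZE/BLOCK_SIZE split: a tile can now span several KV-cache pages, so the kernels gather one physical page id per token (`seq_offset // BLOCK_SIZE`) instead of one scalar per loop iteration. A plain-Python sketch of the indexing under assumed toy sizes (the `block_table` contents and strides here are hypothetical):

```python
# Standalone sketch (not part of the patch): how a TILE_SIZE-wide tile maps
# into a paged KV cache with BLOCK_SIZE-token pages.
import numpy as np

BLOCK_SIZE = 16  # page size of the KV cache
TILE_SIZE = 64   # tokens processed per loop iteration (power of 2)

block_table = np.array([7, 3, 9, 0, 5, 2, 8, 1])  # logical page -> physical page
j = 1                                             # tile index within the sequence
seq_offset = j * TILE_SIZE + np.arange(TILE_SIZE)

# One physical page id per token in the tile (a vectorized gather), mirroring
# tl.load(block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE)
physical_block_idx = block_table[seq_offset // BLOCK_SIZE]
in_page_offset = seq_offset % BLOCK_SIZE

# Each token's row address in the cache is then
#   physical_block_idx * stride_0 + in_page_offset * stride_1 (+ head/dim terms),
# which is why physical_block_idx is broadcast as [:, None] / [None, :] above.
print(physical_block_idx[:18], in_page_offset[:18])
```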
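Note on the sliding-window tile pruning in kernel_unified_attention_2d: rather than a single global start block, the new code derives per-Q-block [tile_start, tile_end) bounds from the window. A plain-Python sketch of the same arithmetic, with all sizes chosen for illustration only:

```python
# Standalone sketch (not part of the patch): the sliding-window tile bounds.
SLIDING_WINDOW = 128
TILE_SIZE = 64
BLOCK_Q = 16
BLOCK_M = 16
num_queries_per_kv = 1


def tile_bounds(q_block_local_idx, context_len, cur_batch_query_len, num_tiles):
    # Query rows covered by this Q-block
    qpos_lo = q_block_local_idx * BLOCK_Q
    qpos_hi = min(qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
                  cur_batch_query_len - 1)
    # Query q attends keys in [context_len + q - SLIDING_WINDOW + 1,
    #                          context_len + q]
    first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
    last_allowed_key = context_len + qpos_hi
    tile_start = max(0, first_allowed_key // TILE_SIZE)
    tile_end = min(last_allowed_key // TILE_SIZE + 1, num_tiles)
    return tile_start, tile_end


# Deep into a long sequence only ~(SLIDING_WINDOW + BLOCK_Q) / TILE_SIZE tiles
# survive: here (62, 65), i.e. 3 of 65 tiles.
print(tile_bounds(q_block_local_idx=0, context_len=4096,
                  cur_batch_query_len=16, num_tiles=65))
```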