From fa824f5f7af473b598788f65d6e78a2e9d7667e7 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 06:27:00 +0000 Subject: [PATCH 1/4] add one shot pa kernel --- aiter/ops/triton/gluon/pa_decode_gluon.py | 33 +++++++++++++---------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index 37e7ddd268..c715aeb5be 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -1481,12 +1481,15 @@ def paged_attention_decode_sliding_window( # ==================== SEQUENCE PROCESSING ==================== query_converted = query_shared.load(qk_lhs_operand_layout) - # query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_operand_layout) - sequence_partition_start_idx = ( - context_length - SLIDING_WINDOW - ) // CONTEXT_PARTITION_SIZE + + if SLIDING_WINDOW > 0: + sequence_partition_start_idx = ( + context_length - SLIDING_WINDOW + ) // CONTEXT_PARTITION_SIZE + else: + sequence_partition_start_idx = 0 sequence_partition_end_idx = gl.cdiv(context_length, CONTEXT_PARTITION_SIZE) - # num_iterations = sequence_partition_end_idx - sequence_partition_start_idx + if QUERY_QUANT_MODE < 0 and COMPUTE_TYPE.is_fp8(): # Quantize bf16 query to fp8 # Convert query to float32 for computation @@ -2549,7 +2552,14 @@ def paged_attention_decode_v2_reduce_kernel( head_size_offsets = tl.arange(0, HEAD_SIZE_POW2) # Initialize global accumulation variables - global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32) + if USE_SINKS: + global_max = tl.load( + sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets), + mask=query_group_offsets < query_group_size, + other=float("-inf"), + ).to(tl.float32) + else: + global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32) global_max_prev = global_max global_exp_sum = tl.zeros((QUERY_GROUP_SIZE_POW2,), dtype=tl.float32) final_output = tl.zeros((QUERY_GROUP_SIZE_POW2, HEAD_SIZE_POW2), dtype=tl.float32) @@ -2596,13 +2606,6 @@ def paged_attention_decode_v2_reduce_kernel( global_exp_sum = update_scale * global_exp_sum + tl.sum(exp_sums, axis=0) global_max_prev = global_max - if USE_SINKS: - sink_token_values = gl.load( - sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets), - mask=query_group_offsets < query_group_size, - ) - global_exp_sum += gl.exp(sink_token_values - global_max) - # ==================== SECOND PASS: COMPUTE RESCALED EXP SUMS AND ACCUMULATE ==================== for iter_idx in range(num_iterations): partition_base = iter_idx * MAX_CONTEXT_PARTITION_NUM @@ -2972,6 +2975,7 @@ def pa_decode_gluon( alibi_slopes: torch.Tensor = None, sinks: torch.Tensor = None, sliding_window: int = 0, + one_shot=None, ) -> None: """ Paged Attention Decode with FP8/BF16/FP16 Support. @@ -3263,7 +3267,8 @@ def pa_decode_gluon( fp8_max_value = torch.finfo(aiter.dtypes.fp8).max # ==================== ATTENTION DECODE KERNEL EXECUTION ==================== - one_shot = sliding_window > 0 + if one_shot is None: + one_shot = sliding_window > 0 _paged_attention_decode_v2_with_dot_kernel_reshape_wrapper( grid, exp_sums, From 4c70a456cc8d6380bc8ed60adfab9fcfc611cfa2 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 08:55:00 +0000 Subject: [PATCH 2/4] fix buffer load in sliding window kernel --- aiter/ops/triton/gluon/pa_decode_gluon.py | 59 +++++++++-------------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index c715aeb5be..c78f8f3809 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -1463,12 +1463,22 @@ def paged_attention_decode_sliding_window( * stride_output_head + output_head_size_offsets[None, :] ) - max_logits = gl.full( - (QUERY_GROUP_SIZE_POW2,), - float("-inf"), - dtype=gl.float32, - layout=gl.SliceLayout(1, qk_linear_layout), - ) + if sinks_ptr is not None: + max_logits_offsets = gl.arange( + 0, QUERY_GROUP_SIZE_POW2, layout=gl.SliceLayout(1, qk_linear_layout) + ) + max_logits = gl.load( + sinks_ptr + (kv_head_idx * query_group_size + max_logits_offsets), + mask=max_logits_offsets < query_group_size, + other=float("-inf"), + ).to(gl.float32) + else: + max_logits = gl.full( + (QUERY_GROUP_SIZE_POW2,), + float("-inf"), + dtype=gl.float32, + layout=gl.SliceLayout(1, qk_linear_layout), + ) exp_sums = gl.full( (QUERY_GROUP_SIZE_POW2,), 0.0, @@ -1527,11 +1537,11 @@ def paged_attention_decode_sliding_window( ) # Create mask for valid blocks valid_block_mask = block_indices < num_kv_blocks - # masked_block_indices = gl.where(valid_block_mask, block_indices, 0) + masked_block_indices = gl.where(valid_block_mask, block_indices, 0) block_table_start_ptr = block_tables_ptr + sequence_idx * stride_block_table_seq kv_block_numbers = gl.amd.cdna3.buffer_load( - ptr=block_table_start_ptr + kv_block_start_idx, offsets=block_indices - ).to(gl.uint32) + ptr=block_table_start_ptr + kv_block_start_idx, offsets=masked_block_indices + ).to(gl.int64) # ==================== KEY LOADING AND PROCESSING ==================== # Calculate key cache offsets and load keys @@ -1543,20 +1553,15 @@ def paged_attention_decode_sliding_window( * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD + contiguous_kv_element_offsets[None, None, None, :] ) - # Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load) - key_tensor = gl.amd.cdna3.buffer_load( - ptr=key_cache_ptr, - offsets=key_block_offsets, - mask=valid_block_mask[:, None, None, None], - ) + # Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load) + key_tensor = gl.load(key_cache_ptr + key_block_offsets) # Prepare QK MFMA while key loads (these don't depend on key data) qk_accumulator = gl.zeros( (QUERY_GROUP_SIZE_POW2, CONTEXT_PARTITION_SIZE), dtype=gl.float32, - layout=qk_mfma_layout, + layout=qk_mfma_layout,q ) - # Load key quantization scales if needed (overlaps with key tensor load) if KV_QUANT_MODE >= 0: if KV_QUANT_MODE == 0: @@ -1625,11 +1630,7 @@ def paged_attention_decode_sliding_window( * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD + value_dim3_offsets[None, None, None, :] ) - value_tensor = gl.amd.cdna3.buffer_load( - ptr=value_cache_ptr, - offsets=value_block_offsets, - mask=valid_block_mask[:, None, None, None], - ) + value_tensor = gl.load(value_cache_ptr + value_block_offsets) # Compute QK attention scores using MFMA (overlaps with value load) attention_scores = gl.amd.cdna3.mfma( query_converted, key_converted, qk_accumulator @@ -1658,11 +1659,7 @@ def paged_attention_decode_sliding_window( ) # Schedule: Start value VMEM load, then QK MFMA - value_tensor = gl.amd.cdna3.buffer_load( - ptr=value_cache_ptr, - offsets=value_block_offsets, - mask=valid_block_mask[:, None, None], - ) + value_tensor = gl.load(value_cache_ptr + value_block_offsets) # Compute QK attention scores using MFMA (overlaps with value load) attention_scores = gl.amd.cdna3.mfma( query_converted, key_converted, qk_accumulator @@ -1795,14 +1792,6 @@ def paged_attention_decode_sliding_window( # ==================== OUTPUT NORMALIZATION AND STORING ==================== # Normalize attention output by softmax denominator - if sinks_ptr is not None: - sinks_values = gl.load( - sinks_ptr + (kv_head_idx * query_group_size + query_group_offsets), - mask=query_group_offsets < query_group_size, - ) - exp_sums += gl.exp( - gl.convert_layout(sinks_values, layout=max_logits.type.layout) - max_logits - ) exp_sums_reciprocal = 1.0 / exp_sums exp_sums_reciprocal_cvt = gl.convert_layout( From 1020175e013b51e5bf55051babb461f4a5901123 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 08:57:06 +0000 Subject: [PATCH 3/4] fix typo --- aiter/ops/triton/gluon/pa_decode_gluon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index c78f8f3809..e6de15218e 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -1560,7 +1560,7 @@ def paged_attention_decode_sliding_window( qk_accumulator = gl.zeros( (QUERY_GROUP_SIZE_POW2, CONTEXT_PARTITION_SIZE), dtype=gl.float32, - layout=qk_mfma_layout,q + layout=qk_mfma_layout, ) # Load key quantization scales if needed (overlaps with key tensor load) if KV_QUANT_MODE >= 0: From c6963fdcda6e647ea37e3581fb1df941940f3ade Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 11:15:50 +0000 Subject: [PATCH 4/4] revert --- aiter/ops/triton/gluon/pa_decode_gluon.py | 53 +++++++++++++---------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index e6de15218e..9c07d4c4ac 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -1463,22 +1463,13 @@ def paged_attention_decode_sliding_window( * stride_output_head + output_head_size_offsets[None, :] ) - if sinks_ptr is not None: - max_logits_offsets = gl.arange( - 0, QUERY_GROUP_SIZE_POW2, layout=gl.SliceLayout(1, qk_linear_layout) - ) - max_logits = gl.load( - sinks_ptr + (kv_head_idx * query_group_size + max_logits_offsets), - mask=max_logits_offsets < query_group_size, - other=float("-inf"), - ).to(gl.float32) - else: - max_logits = gl.full( - (QUERY_GROUP_SIZE_POW2,), - float("-inf"), - dtype=gl.float32, - layout=gl.SliceLayout(1, qk_linear_layout), - ) + + max_logits = gl.full( + (QUERY_GROUP_SIZE_POW2,), + float("-inf"), + dtype=gl.float32, + layout=gl.SliceLayout(1, qk_linear_layout), + ) exp_sums = gl.full( (QUERY_GROUP_SIZE_POW2,), 0.0, @@ -1790,6 +1781,14 @@ def paged_attention_decode_sliding_window( attention_accumulator += attention_output max_logits = new_max_logits + if sinks_ptr is not None: + sinks_values = gl.load( + sinks_ptr + (kv_head_idx * query_group_size + query_group_offsets), + mask=query_group_offsets < query_group_size, + ) + exp_sums += gl.exp( + gl.convert_layout(sinks_values, layout=max_logits.type.layout) - max_logits + ) # ==================== OUTPUT NORMALIZATION AND STORING ==================== # Normalize attention output by softmax denominator @@ -2541,14 +2540,14 @@ def paged_attention_decode_v2_reduce_kernel( head_size_offsets = tl.arange(0, HEAD_SIZE_POW2) # Initialize global accumulation variables - if USE_SINKS: - global_max = tl.load( - sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets), - mask=query_group_offsets < query_group_size, - other=float("-inf"), - ).to(tl.float32) - else: - global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32) + # if USE_SINKS: + # global_max = tl.load( + # sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets), + # mask=query_group_offsets < query_group_size, + # other=float("-inf"), + # ).to(tl.float32) + # else: + global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32) global_max_prev = global_max global_exp_sum = tl.zeros((QUERY_GROUP_SIZE_POW2,), dtype=tl.float32) final_output = tl.zeros((QUERY_GROUP_SIZE_POW2, HEAD_SIZE_POW2), dtype=tl.float32) @@ -2595,6 +2594,12 @@ def paged_attention_decode_v2_reduce_kernel( global_exp_sum = update_scale * global_exp_sum + tl.sum(exp_sums, axis=0) global_max_prev = global_max + if USE_SINKS: + sink_token_values = gl.load( + sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets), + mask=query_group_offsets < query_group_size, + ) + global_exp_sum += gl.exp(sink_token_values - global_max) # ==================== SECOND PASS: COMPUTE RESCALED EXP SUMS AND ACCUMULATE ==================== for iter_idx in range(num_iterations): partition_base = iter_idx * MAX_CONTEXT_PARTITION_NUM