
[ROCm] Add Aiter PagedAttention with Sliding Window support #28719

Closed
sammysun0711 wants to merge 9 commits into vllm-project:main from
sammysun0711:add_aiter_pa_sliding_windows_support

Conversation


@sammysun0711 sammysun0711 commented Nov 14, 2025

Purpose

This PR aims to add Aiter PagedAttention (PA) sliding window support, which fixes a google/gemma-3-27b-it bfloat16 model accuracy issue with Aiter PA.

google/gemma-3-27b-it is trained with the additional parameter "sliding_window": 1024.

  • Triton unified attention (the default) unified_attention passes the sliding window as input:

    unified_attention(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        q_descale=None,  # Not supported
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
        sinks=self.sinks,
        output_scale=output_scale,
    )
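Note that the window_size passed above is typically a (left, right) tuple rather than the raw model-config integer. A hedged sketch of the usual conversion, following the FlashAttention-style convention (the helper name and exact mapping here are my assumption, not taken from this PR):

```python
from typing import Optional, Tuple

def to_window_size(sliding_window: Optional[int]) -> Tuple[int, int]:
    """Convert a model-config sliding_window value to a (left, right)
    window tuple in the FlashAttention-style convention (assumed here):
    (-1, -1) means unlimited context; with causal attention a config
    value of W usually maps to (W - 1, 0)."""
    if sliding_window is None:
        return (-1, -1)
    return (sliding_window - 1, 0)
```

For example, gemma3's "sliding_window": 1024 would map to (1023, 0): each query may look back at most 1023 positions plus itself, and never ahead.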

  • Aiter PA torch.ops.aiter.paged_attention_v1 does not take sliding_window as an input:

    torch.ops.aiter.paged_attention_v1(
        output[:num_decode_tokens],
        workspace_buffer,
        query[:num_decode_tokens],
        key_cache,
        value_cache,
        self.scale,
        attn_metadata.block_table[:num_decodes],
        attn_metadata.query_start_loc[:num_decodes],
        attn_metadata.seq_lens[:num_decodes],
        attn_metadata.max_seq_len,
        self.alibi_slopes,
        self.kv_cache_dtype,
        "NHD",
        self.logits_soft_cap,
        layer._k_scale,
        layer._v_scale,
        None,
        _PARTITION_SIZE_ROCM,
    )

If the input prompt is longer than 1024 tokens, the missing sliding-window handling causes the gemma3 accuracy degradation with Aiter PA.
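To see why this matters: a sliding window of 1024 means each query token may attend only to the most recent 1024 key positions, in addition to the usual causal constraint. A minimal pure-Python sketch of the combined mask (illustrative only, not the Aiter kernel; the function name is mine):

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list:
    """mask[i][j] is True when query position i may attend to key
    position j: causal (j <= i) AND within the window (i - j < window)."""
    return [[(j <= i) and (i - j < window) for j in range(seq_len)]
            for i in range(seq_len)]
```

A kernel that ignores the window term effectively lets every query attend to the full causal prefix, so outputs diverge from training behavior once the prompt exceeds the window length.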

To fix the gemma3 accuracy issue, the following 3 PRs are required:

Opened as a draft PR for now since it depends on the other 2 PRs.

Test Plan

lm_eval test with the gsm8k dataset

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added rocm Related to AMD ROCm v1 labels Nov 14, 2025
Signed-off-by: Xiake Sun <xiake.sun@amd.com>
@sammysun0711
Contributor Author

Sorry, I need to close this due to a rebase issue; continuing in a new PR: #29065.

