Add Paged Attention Op for CUDA SM80 support #24595
Merged
aciddelgado merged 16 commits into main on Jun 12, 2025
Conversation
tianleiwu reviewed May 1, 2025
tianleiwu reviewed Jun 11, 2025
tianleiwu (Contributor) left a comment:
What design change would be needed if we want to support FP8 or FP4 paged attention in the future?
Author (Contributor) replied:
A new kernel would be necessary.
tianleiwu approved these changes Jun 12, 2025
ankus-qti pushed a commit to CodeLinaro/onnxruntime that referenced this pull request on Nov 25, 2025
tianleiwu pushed a commit that referenced this pull request on Apr 28, 2026
…ttention (#28200)

### Description

Adds a CUTLASS memory-efficient attention (MEA) fallback to the CUDA PagedAttention op, enabling the operator on **sm<80 (Turing / Volta / Pascal) with fp16** for the first time. On sm>=80 the default FlashAttention path is unchanged; MEA is reachable via `ORT_DISABLE_FLASH_ATTENTION=1` or the `sdpa_kernel` CUDA provider option for debugging and perf comparison.

| Environment | Before | After |
|---|:---:|:---:|
| sm<80 + fp16 | ❌ error | ✅ MEA |
| sm<80 + bf16 | ❌ error | ❌ error (MEA requires sm>=80 for bf16) |
| sm>=80 + fp16/bf16 (default) | ✅ FA | ✅ FA (unchanged) |
| sm>=80 + `ORT_DISABLE_FLASH_ATTENTION=1` / `sdpa_kernel=EFFICIENT_ATTENTION` | ❌ error | ✅ MEA |

### Motivation and Context

The original PagedAttention PR (#24595) landed with the title "CUDA SM80 support" — the op errors out immediately whenever FlashAttention isn't available (sm<80 or `USE_FLASH_ATTENTION=0` builds). During that review, @tianleiwu flagged that the interface was too FlashAttention-specific (*"not good for other EP like WebGPU, CPU etc."*) and @aciddelgado agreed the FA-specific dependencies could be lifted at the kernel level.

This PR closes that gap for sm<80 fp16 by mirroring the pattern established in #20012 ("Packed QKV and Rotary Embedding Support for sm<80 GQA"). The same CUTLASS memory-efficient attention backend that covers GQA's sm<80 path now covers PagedAttention.

Related work:

- #20012 — direct pattern template (sm<80 GQA MEA fallback)
- #24595 — original PagedAttention PR
- #27516 — MS canonical FA → MEA → Unfused cascade ordering
- #27880 — ONNX Attention CUDA fallback coverage gaps
- #27992 — MEA decode + unfused softcap work (same flavor)

### Implementation

**Dispatch cascade** in `paged_attention.cc`: FlashAttention is preferred; fall back to MemoryEfficientAttention via `has_memory_efficient_attention(sm, is_half, is_bf16, head_size, head_size)`. No custom head-size or dtype bounds are hardcoded — MEA's own helper gates fp16 sm>=53 / bf16 sm>=80 / head_size <= 1024 and `% 8 == 0`. This keeps us forward-compatible with any future expansion of MEA's supported range.

**MEA path** (`UnfusedAttention<T>`):

1. Reuses existing preprocessing: `LaunchGetCumulativeSeqlensKV` (hoisted to `paged_attention.cc` so both FA and MEA paths consume a pre-populated buffer — single-producer refactor), rotary, packed-QKV unpack, `ReshapeAndCache`.
2. New `GatherAndExpandPagedKVCache` CUDA kernel walks `block_table` to gather paged K/V into a packed-varlen `[total_kv_tokens, num_heads, head_size]` buffer, folding in GQA head expansion (so downstream MEA sees `num_heads` uniformly).
3. Dispatches to `run_memory_efficient_attention` in **varlen mode** via `seqstart_q_ptr = cumulative_seqlens_q` + `seqstart_k_ptr = cumulative_seqlens_kv` (and `has_custom_right_padding = false`). No padding is required; the layout matches the kernel's expected `[total_tokens, num_heads, head_size]` with BSNH strides.

**Scratch allocation**: the MEA path D->H syncs `cumulative_seqlens_kv[batch_size]` via a pinned buffer to obtain `total_kv_tokens` on the host for tight `gathered_key` / `gathered_value` / `fmha_buffer` allocation. This adds one `cudaStreamSynchronize` per forward call — acceptable for a compatibility fallback (FA remains the hot path on supported hardware). Over-allocation (the no-sync alternative) would consume `B × max_num_blocks_per_seq × block_size × num_heads × head_size × 2 × sizeof(T)`, which reaches GB-scale for realistic GQA models and was rejected.
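For a sense of scale, here is a back-of-envelope version of that rejected over-allocation in Python; all shape values below are illustrative assumptions for a plausible GQA configuration, not figures from the change.

```python
# Back-of-envelope for the rejected no-sync over-allocation.
# All shape values are illustrative assumptions, not measurements.
batch_size = 32
max_num_blocks_per_seq = 512   # e.g. 8K max context with block_size = 16
block_size = 16
num_heads = 32                 # KV heads after GQA expansion
head_size = 128
dtype_bytes = 2                # sizeof(T) for fp16 / bf16

# B x max_num_blocks_per_seq x block_size x num_heads x head_size x 2 x sizeof(T)
gathered_kv_bytes = (batch_size * max_num_blocks_per_seq * block_size
                     * num_heads * head_size * 2 * dtype_bytes)
print(f"{gathered_kv_bytes / 2**30:.1f} GiB")  # ~4.0 GiB for these numbers
```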
`fmha_buffer` is sized with `sizeof(float)` (matching the GQA EfficientAttention pattern at `group_query_attention.cc:482`) because MEA's output accumulator is fp32 regardless of input dtype.

### Testing

A new `TestPagedAttentionMEA` class in `test_paged_attention_cuda.py` runs the existing parity matrix (rotary on/off, rotary_interleaved on/off, packed-QKV on/off, local window on/off, softcap 0/50, varied head sizes/shapes) against the MEA path via the `sdpa_kernel` CUDA provider option set to `EFFICIENT_ATTENTION` (=2, from the `AttentionBackend` enum). Using a per-session provider option instead of an env var means both FA and MEA test classes coexist in the same pytest process — each InferenceSession creates its own CUDA EP with its own `attention_kernel_options_`. The existing `TestPagedAttention` class is skipped wholesale on sm<80 by its `has_flash_attention()` gate, so without the new MEA class the fallback path would have no CI coverage.

**Local verification** (NVIDIA A100 80GB, CUDA 12.8, GCC 13.3):

```
TestPagedAttention:    24/24 passed (~60s)  # FA baseline — no regression
TestPagedAttentionMEA: 24/24 passed (~59s)  # new MEA path
```

Tolerance: `rtol = atol = 5e-3` against the same torch reference used by the FA parity test. All combinations match.

**sm<80 hardware coverage**: I don't have local Turing / Volta / Pascal hardware, so real-SM coverage relies on MS CI. The code path exercised on A100 via `sdpa_kernel=EFFICIENT_ATTENTION` is the same one taken on sm<80; only the underlying CUTLASS kernel (`run_memory_efficient_attention_sm50/70/75/80`) differs per SM, and those are upstream and unmodified by this change.

**Build note**: built with `--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 CMAKE_CXX_STANDARD=20`. The explicit C++20 define was needed because the initial configure resolved `CMAKE_CXX_STANDARD=17`, under which `ort_version_check.h`'s `consteval` usage fails to compile. Unrelated to this change.
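As a concrete illustration of the per-session override described under Testing, a minimal sketch of forcing the MEA path from the ONNX Runtime Python API might look like the following; the model path is a placeholder and the value 2 comes from the description above.

```python
# Minimal sketch: force the MEA path for one session via the CUDA EP's
# sdpa_kernel provider option (2 = EFFICIENT_ATTENTION, per the description
# above). The model path is a placeholder, not a file from this PR.
import onnxruntime as ort

EFFICIENT_ATTENTION = 2

sess = ort.InferenceSession(
    "paged_attention_model.onnx",  # placeholder
    providers=[("CUDAExecutionProvider", {"sdpa_kernel": str(EFFICIENT_ATTENTION)})],
)

# Other sessions in the same process keep the default FlashAttention path,
# since the option is scoped to this session's CUDA EP instance rather than
# a process-wide environment variable.
```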
Description
Adds the Paged Attention Op, which enables use of a Paged KV Cache. Inputs to this op are unpadded (packed / varlen), so Cumulative Sequence Lengths are a required input.
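As a rough illustration of that packed (varlen) layout, a minimal sketch follows; the names and shapes are assumptions for the example, not the op's exact schema.

```python
# Illustrative only: how cumulative sequence lengths describe packed (varlen)
# input. Names and shapes are assumptions, not the op's schema.
import numpy as np

seq_lens = np.array([3, 5, 2], dtype=np.int32)             # tokens per sequence
cum_seqlens = np.concatenate(([0], np.cumsum(seq_lens)))    # [0, 3, 8, 10]

# A packed query tensor then has shape [total_tokens, num_heads, head_size];
# sequence i occupies rows cum_seqlens[i] : cum_seqlens[i + 1].
total_tokens = int(cum_seqlens[-1])                          # 10
```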
Motivation and Context
Adding this op to ONNX Runtime is necessary to allow the GenAI team to enable a continuous batching server API.