
Add Paged Attention Op for CUDA SM80 support #24595

Merged
aciddelgado merged 16 commits into main from aciddelgado/paged_attention on Jun 12, 2025

Conversation

@aciddelgado
Contributor

Description

Adds Paged Attention Op which enables use of a Paged KV Cache. Inputs to this op are unpadded (packed / varlen), so Cumulative Sequence Lengths are a required input.
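For reference, a minimal NumPy sketch of how packed (varlen) inputs and the cumulative-sequence-length tensor could be prepared; the array names here are illustrative, not the op's actual input names:

```python
import numpy as np

# Three requests in one continuous batch, each with its own length (no padding).
seq_lens = np.array([5, 2, 9], dtype=np.int32)
num_heads, head_size = 8, 64

# Packed query: all tokens concatenated along dim 0.
total_tokens = int(seq_lens.sum())
query_packed = np.random.randn(total_tokens, num_heads * head_size).astype(np.float16)

# Cumulative sequence lengths (batch_size + 1 entries); entry i is where
# sequence i starts in the packed token dimension.
cum_seq_lens = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.int32)
# cum_seq_lens == [0, 5, 7, 16]; the kernel uses these offsets instead of padding masks.
```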

Motivation and Context

Adding this op to ONNXRuntime is necessary to allow the GenAI team to enable a continuous batching server API.


@github-actions Bot left a comment


You can commit the suggested changes from lintrunner.


@tianleiwu left a comment


What design change is needed if we want to support FP8 or FP4 paged attention in the future?

@aciddelgado
Contributor Author

What design change is needed if we want to support FP8 or FP4 paged attention in the future?

A new kernel would be necessary.

@aciddelgado aciddelgado merged commit 30c5f05 into main Jun 12, 2025
90 checks passed
@aciddelgado aciddelgado deleted the aciddelgado/paged_attention branch June 12, 2025 17:40
ankus-qti pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Nov 25, 2025
### Description
Adds Paged Attention Op which enables use of a Paged KV Cache. Inputs to this
op are unpadded (packed / varlen), so Cumulative Sequence Lengths are a
required input.


### Motivation and Context
Adding this op to ONNXRuntime is necessary to allow the GenAI team to
enable a continuous batching server API.
tianleiwu pushed a commit that referenced this pull request Apr 28, 2026
…ttention (#28200)

### Description

Adds a CUTLASS memory-efficient attention (MEA) fallback to the CUDA
PagedAttention op, enabling the operator on **sm<80 (Turing / Volta /
Pascal) with fp16** for the first time. On sm>=80 the default
FlashAttention path is unchanged; MEA is reachable via
`ORT_DISABLE_FLASH_ATTENTION=1` or the `sdpa_kernel` CUDA provider
option for debugging and perf comparison.


| Environment | Before | After |
|---|:---:|:---:|
| sm<80 + fp16 | ❌ error | ✅ MEA |
| sm<80 + bf16 | ❌ error | ❌ error (MEA requires sm>=80 for bf16) |
| sm>=80 + fp16/bf16 (default) | ✅ FA | ✅ FA (unchanged) |
| sm>=80 + `ORT_DISABLE_FLASH_ATTENTION=1` / `sdpa_kernel=EFFICIENT_ATTENTION` | ❌ error | ✅ MEA |
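
A hedged usage sketch of the two switches in the last row; the `sdpa_kernel` provider-option key and the value `2` (`EFFICIENT_ATTENTION`) follow this PR's description, and the model path is a placeholder:

```python
import os
import onnxruntime as ort

# Switch 1: environment variable, disables FlashAttention process-wide.
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"

# Switch 2: per-session CUDA EP provider option selecting the MEA backend
# (2 == EFFICIENT_ATTENTION in the AttentionBackend enum, per the PR text).
sess = ort.InferenceSession(
    "model_with_paged_attention.onnx",  # placeholder path
    providers=[("CUDAExecutionProvider", {"sdpa_kernel": "2"})],
)
```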

### Motivation and Context

The original PagedAttention PR (#24595) landed with the title "CUDA SM80
support" — the op errors out immediately whenever FlashAttention isn't
available (sm<80 or `USE_FLASH_ATTENTION=0` builds). During that review,
@tianleiwu flagged that the interface was too FlashAttention-specific
(*"not good for other EP like WebGPU, CPU etc."*) and @aciddelgado
agreed the FA-specific dependencies could be lifted at the kernel level.

This PR closes that gap for sm<80 fp16 by mirroring the exact pattern
established in #20012 ("Packed QKV and Rotary Embedding Support for
sm<80 GQA"). The same CUTLASS memory-efficient attention backend that
covers GQA's sm<80 path now covers PagedAttention.

Related work:
- #20012 — direct pattern template (sm<80 GQA MEA fallback)
- #24595 — original PagedAttention PR
- #27516 — MS canonical FA → MEA → Unfused cascade ordering
- #27880 — ONNX Attention CUDA fallback coverage gaps
- #27992 — MEA decode + unfused softcap work (same flavor)

### Implementation

**Dispatch cascade** in `paged_attention.cc`: FlashAttention preferred;
fall back to MemoryEfficientAttention via
`has_memory_efficient_attention(sm, is_half, is_bf16, head_size,
head_size)`. No custom head-size or dtype bounds hardcoded — MEA's own
helper gates fp16 sm>=53 / bf16 sm>=80 / head_size <= 1024 and `% 8 ==
0`. This keeps us forward-compatible with any future expansion of MEA's
supported range.
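
The following is a Python paraphrase of that gate, using only the bounds quoted above (fp16 needs sm >= 53, bf16 needs sm >= 80, head_size <= 1024 and divisible by 8, FlashAttention needs sm >= 80); it is a sketch of the dispatch order, not the actual C++ in `paged_attention.cc`:

```python
def choose_backend(sm: int, is_bf16: bool, head_size: int, flash_disabled: bool) -> str:
    # FlashAttention is preferred whenever it is available (sm >= 80 and not disabled).
    if sm >= 80 and not flash_disabled:
        return "flash_attention"
    # Otherwise fall back to memory-efficient attention if its own gate passes.
    mea_ok = (
        head_size <= 1024
        and head_size % 8 == 0
        and (sm >= 80 if is_bf16 else sm >= 53)  # bf16 needs sm>=80, fp16 needs sm>=53
    )
    if mea_ok:
        return "memory_efficient_attention"
    raise RuntimeError("PagedAttention: no supported attention backend for this GPU/dtype")
```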

**MEA path** (`UnfusedAttention<T>`):
1. Reuses existing preprocessing: `LaunchGetCumulativeSeqlensKV`
(hoisted to `paged_attention.cc` so both FA and MEA paths consume a
pre-populated buffer — single-producer refactor), rotary, packed-QKV
unpack, `ReshapeAndCache`.
2. New `GatherAndExpandPagedKVCache` CUDA kernel walks `block_table` to
gather paged K/V into a packed-varlen `[total_kv_tokens, num_heads,
head_size]` buffer, folding in GQA head expansion (so downstream MEA
sees `num_heads` uniformly); a rough sketch of this gather follows the list.
3. Dispatches to `run_memory_efficient_attention` in **varlen mode** via
`seqstart_q_ptr = cumulative_seqlens_q` + `seqstart_k_ptr =
cumulative_seqlens_kv` (and `has_custom_right_padding = false`). No
padding required; layout matches the kernel's expected `[total_tokens,
num_heads, head_size]` with BSNH strides.
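
A rough NumPy sketch of the gather in step 2; the `[num_blocks, block_size, num_kv_heads, head_size]` cache layout and the function name are assumptions for illustration, not the kernel's exact interface:

```python
import numpy as np

def gather_and_expand(key_cache, block_table, kv_seq_lens, num_heads):
    """key_cache: [num_blocks, block_size, num_kv_heads, head_size] paged cache.
    block_table: [batch_size, max_blocks_per_seq] physical block ids per sequence.
    Returns a packed [total_kv_tokens, num_heads, head_size] buffer with GQA expansion."""
    num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
    repeat = num_heads // num_kv_heads  # GQA head-expansion factor
    gathered = []
    for b, kv_len in enumerate(kv_seq_lens):
        for tok in range(kv_len):
            block_id = block_table[b, tok // block_size]        # logical -> physical block
            k = key_cache[block_id, tok % block_size]           # [num_kv_heads, head_size]
            gathered.append(np.repeat(k, repeat, axis=0))       # [num_heads, head_size]
    return np.stack(gathered)                                   # [total_kv_tokens, num_heads, head_size]
```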

**Scratch allocation**: the MEA path D->H syncs
`cumulative_seqlens_kv[batch_size]` via a pinned buffer to obtain
`total_kv_tokens` on the host for tight `gathered_key` /
`gathered_value` / `fmha_buffer` allocation. This adds a
`cudaStreamSynchronize` per forward call — acceptable for a
compatibility fallback (FA remains the hot path on supported hardware).
Over-allocation (the no-sync alternative) would consume `B ×
max_num_blocks_per_seq × block_size × num_heads × head_size × 2 ×
sizeof(T)`, which reaches GB-scale for realistic GQA models and was
rejected.
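
For scale, a quick back-of-envelope evaluation of that formula under assumed, illustrative shapes for a GQA model:

```python
# Rejected over-allocation (no-sync) sizing:
# B * max_num_blocks_per_seq * block_size * num_heads * head_size * 2 * sizeof(T)
B, max_num_blocks_per_seq, block_size = 32, 256, 16   # 4096-token KV budget per sequence (assumed)
num_heads, head_size, sizeof_T = 32, 128, 2           # fp16 (assumed shapes)
bytes_needed = B * max_num_blocks_per_seq * block_size * num_heads * head_size * 2 * sizeof_T
print(f"{bytes_needed / 2**30:.1f} GiB")              # -> 2.0 GiB just for the gathered K and V buffers
```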

`fmha_buffer` is sized with `sizeof(float)` (matching the GQA
EfficientAttention pattern at `group_query_attention.cc:482`) because
MEA's output accumulator is fp32 regardless of input dtype.

### Testing

New `TestPagedAttentionMEA` class in `test_paged_attention_cuda.py` runs
the existing parity matrix (rotary on/off, rotary_interleaved on/off,
packed-QKV on/off, local window on/off, softcap 0/50, varied head
sizes/shapes) against the MEA path via the `sdpa_kernel` CUDA provider
option set to `EFFICIENT_ATTENTION` (=2, from `AttentionBackend` enum).
Using a per-session provider option instead of an env var means both FA
and MEA test classes coexist in the same pytest process — each
InferenceSession creates its own CUDA EP with its own
`attention_kernel_options_`.
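
A sketch of that design point, with a placeholder model path and the `sdpa_kernel` key/value taken from the description above:

```python
import onnxruntime as ort

model = "paged_attention_test_model.onnx"   # placeholder
# Default session: FlashAttention path on supported hardware.
fa_sess = ort.InferenceSession(model, providers=["CUDAExecutionProvider"])
# Second session in the same process: MEA path via the per-session provider option.
mea_sess = ort.InferenceSession(
    model, providers=[("CUDAExecutionProvider", {"sdpa_kernel": "2"})]
)
# Each session constructs its own CUDA EP (and attention_kernel_options_),
# so both backends coexist in one pytest process, unlike a process-wide env var.
```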

The existing `TestPagedAttention` class is skipped wholesale on sm<80 by
its `has_flash_attention()` gate, so without the new MEA class the
fallback path would have no CI coverage.

**Local verification** (NVIDIA A100 80GB, CUDA 12.8, GCC 13.3):

```
TestPagedAttention:       24/24 passed (~60s)   # FA baseline — no regression
TestPagedAttentionMEA:    24/24 passed (~59s)   # new MEA path
```

Tolerance: `rtol = atol = 5e-3` against the same torch reference used by
the FA parity test. All combinations match.

**sm<80 hardware coverage**: I don't have local Turing / Volta / Pascal
hardware, so real-SM coverage relies on MS CI. The code path exercised
on A100 via `sdpa_kernel=EFFICIENT_ATTENTION` is the same one taken on
sm<80; only the underlying CUTLASS kernel
(`run_memory_efficient_attention_sm50/70/75/80`) differs per SM, and
those are upstream and unmodified by this change.

**Build note**: built with `--cmake_extra_defines
CMAKE_CUDA_ARCHITECTURES=80 CMAKE_CXX_STANDARD=20`. The explicit C++20
define was needed because the initial configure resolved
`CMAKE_CXX_STANDARD=17`, under which `ort_version_check.h`'s `consteval`
usage fails to compile. Unrelated to this change.
