Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Priority is **1 = highest** (tried first).
| 2 | `FLASH_ATTN` |
| 3 | `TRITON_ATTN` |
| 4 | `FLEX_ATTENTION` |
| 5 | `TURBOQUANT` |

**Ampere/Hopper (SM 8.x-9.x):**

Expand All @@ -115,6 +116,7 @@ Priority is **1 = highest** (tried first).
| 2 | `FLASHINFER` |
| 3 | `TRITON_ATTN` |
| 4 | `FLEX_ATTENTION` |
| 5 | `TURBOQUANT` |

### MLA Attention (DeepSeek-style)

Expand Down
7 changes: 2 additions & 5 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,15 @@ def _get_backend_priorities(
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TURBOQUANT,
]
else:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TURBOQUANT,
]
Comment thread
mgoin marked this conversation as resolved.


Expand Down Expand Up @@ -255,11 +257,6 @@ def get_valid_backends(
valid_backends_priorities = []
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}

# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
return [(AttentionBackendEnum.TURBOQUANT, 0)], {}

backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla,
device_capability,
Expand Down
Loading