
[Attn,KV-cache] Use per-head scales in the attention selector#34281

Merged
LucasWilkinson merged 5 commits into vllm-project:main from eldarkurtic:use-perhead-scales-in-attn-selector on Feb 24, 2026

Conversation

@eldarkurtic (Contributor) commented Feb 10, 2026

As requested by @MatthewBonanni and @LucasWilkinson in #30141, attention backends should be filtered during backend selection based on whether they support per-head attention quantization scales.

This enables early failure when a user attempts to load a model that requires per-head scales but no compatible attention backend is available.
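The idea can be sketched as follows. This is an illustrative mock-up, not vLLM's actual code: the class names, the `select_backend` function, and the default-to-unsupported policy are assumptions chosen only to make the selection-time filtering concrete.

```python
# Hypothetical sketch of early backend filtering on per-head quant scale
# support. All names here are illustrative, not vLLM's actual API.

class AttentionBackend:
    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        # Conservative default: a backend must explicitly opt in.
        return False


class FlashAttentionBackend(AttentionBackend):
    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        return True


class TritonAttentionBackend(AttentionBackend):
    pass  # inherits the default: no per-head scale support


def select_backend(candidates, requires_per_head_quant_scales: bool):
    for backend in candidates:
        if (requires_per_head_quant_scales
                and not backend.supports_per_head_quant_scales()):
            continue  # filter out incompatible backends up front
        return backend
    # Fail at selection time instead of via a runtime assertion later.
    raise ValueError("No attention backend supports per-head quant scales")


backend = select_backend(
    [TritonAttentionBackend, FlashAttentionBackend],
    requires_per_head_quant_scales=True,
)
print(backend.__name__)  # FlashAttentionBackend
```

The point of the sketch is the ordering: the capability check runs inside the selector loop, so an incompatible configuration is rejected before any model weights are loaded.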

@gemini-code-assist (bot) left a comment

Code Review

This pull request refactors the check for per-head attention quantization scales support. The changes introduce a mechanism to filter attention backends early during selection based on this capability. This is achieved by adding a requires_per_head_quant_scales parameter that is propagated down to the attention backend selector. The AttentionBackend class is updated with a supports_per_head_quant_scales method to facilitate this check, with FlashAttentionBackend correctly implementing it based on the FlashAttention version. The previous runtime assertion is removed, enabling earlier failure for unsupported configurations, which improves user experience. The changes are well-structured and correctly implemented.
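The version-gated check the review describes might look roughly like this. The class names and the "version 3" cutoff are illustrative assumptions for the sake of a self-contained example, not a statement about which FlashAttention versions actually support per-head scales in vLLM.

```python
# Hedged sketch of a capability check gated on the FlashAttention version.
# Names and the version cutoff are assumptions, not vLLM's actual code.

class FlashAttentionBackendSketch:
    # In vLLM the FA version is resolved at runtime; a plain class
    # attribute keeps this example self-contained.
    flash_attn_version: int = 3

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        # Report support only when the resolved version is new enough
        # (assumed cutoff, for illustration only).
        return cls.flash_attn_version >= 3


class FlashAttentionV2Sketch(FlashAttentionBackendSketch):
    flash_attn_version = 2  # older build: reports no per-head support
```

Making this a classmethod (rather than an instance property) matters for the design: the selector can query support before any backend object is constructed.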

@eldarkurtic force-pushed the use-perhead-scales-in-attn-selector branch from b3cf85b to 656470c on February 10, 2026 22:13
@MatthewBonanni (Collaborator) left a comment

LGTM, maybe we could add a case to the attention selector test to verify this is working?

@eldarkurtic (Contributor, Author)

Tests added; let me know if this is what you had in mind, @MatthewBonanni.

mergify bot commented Feb 11, 2026

Hi @eldarkurtic, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@MatthewBonanni (Collaborator) left a comment

Thanks for adding the test! Just one comment

    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
        patch(supports_attr, return_value=supports_per_head),

@MatthewBonanni (Collaborator) commented on this line:

By patching this I think the test isn't actually exercising the backend support; it'll just always pass as long as supports_per_head matches should_succeed in the test case. Can you get rid of this patch?

Suggested change (delete this line):
-        patch(supports_attr, return_value=supports_per_head),
    ],
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, supports_per_head: bool, should_succeed: bool

@MatthewBonanni (Collaborator) commented on this line: see below

Suggested change:
-    backend_name: str, supports_per_head: bool, should_succeed: bool
+    backend_name: str, should_succeed: bool

@eldarkurtic (Contributor, Author)

@MatthewBonanni great point, thanks a lot! I wanted to use that patch to distinguish between FA2 and FA3, but I just found out that I can simply pass flash_attn_version when selecting the backend. The test is updated and fixed now.
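The suggested test shape can be sketched like this: parametrize over backend name and expected outcome, and let the backend's own capability method run instead of patching it. Everything here (the stub classes, `get_attn_backend`, the backend-name strings) is an illustrative stand-in for the vLLM test, not its actual code.

```python
# Sketch of the reviewed test without the supports_per_head patch:
# the real (here: stubbed) capability check is exercised directly.
# All names are hypothetical stand-ins for vLLM's test code.

class FlashAttnV2:
    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        return False  # stand-in for an FA2 build


class FlashAttnV3:
    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        return True  # stand-in for an FA3 build


BACKENDS = {"FLASH_ATTN_V2": FlashAttnV2, "FLASH_ATTN_V3": FlashAttnV3}


def get_attn_backend(backend_name: str,
                     requires_per_head_quant_scales: bool = False):
    backend = BACKENDS[backend_name]
    if (requires_per_head_quant_scales
            and not backend.supports_per_head_quant_scales()):
        raise ValueError(
            f"{backend_name} does not support per-head quant scales")
    return backend


# (backend_name, should_succeed): no supports_per_head parameter needed,
# since the backend's own capability method decides the outcome.
CASES = [("FLASH_ATTN_V2", False), ("FLASH_ATTN_V3", True)]

for backend_name, should_succeed in CASES:
    try:
        get_attn_backend(backend_name, requires_per_head_quant_scales=True)
        succeeded = True
    except ValueError:
        succeeded = False
    assert succeeded == should_succeed
```

With the patch removed, a regression in the capability method itself would now fail the test, which is exactly what the review asked for.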

@LucasWilkinson added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Feb 17, 2026
@LucasWilkinson (Collaborator) left a comment

LGTM, thanks!

mergify bot added labels on Feb 19, 2026: new-model, performance, qwen, gpt-oss, nvidia, rocm, cpu, tpu, kv-connector
github-project-automation bot moved this to Ready in NVIDIA, to Todo in AMD, and from To Triage to Ready in gpt-oss Issues & Enhancements on Feb 19, 2026
mergify bot commented Feb 19, 2026

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @eldarkurtic.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Feb 19, 2026
@eldarkurtic force-pushed the use-perhead-scales-in-attn-selector branch from 0ec4a22 to a027512 on February 19, 2026 21:29
mergify bot removed the tpu and needs-rebase labels on Feb 19, 2026
mergify bot commented Feb 19, 2026

Hi @eldarkurtic, the pre-commit checks have failed again. (Same instructions as in the earlier pre-commit comment above.)

Eldar Kurtic and others added 5 commits February 23, 2026 14:44
Signed-off-by: Your Name <you@example.com>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>

Labels

ci/build, cpu (Related to CPU backends), deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), frontend, gpt-oss (Related to GPT-OSS models), kv-connector, llama (Related to Llama models), multi-modality (Related to multi-modality, #4194), new-model (Requests to new models), nvidia, performance (Performance-related issues), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, structured-output, v1

Projects

Status: Done

3 participants