Skip to content

[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for rope (int8, fp8)#34243

Merged
mgoin merged 3 commits intovllm-project:mainfrom
eldarkurtic:fix-attn-quant-llama4
Feb 11, 2026
Merged

[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for rope (int8, fp8)#34243
mgoin merged 3 commits intovllm-project:mainfrom
eldarkurtic:fix-attn-quant-llama4

Conversation

@eldarkurtic
Copy link
Copy Markdown
Contributor

@eldarkurtic eldarkurtic commented Feb 10, 2026

Llama-4 weights of q/k_proj are permuted during model loading to prepare the model for interleaved/gpt-neox rope.
The same permutation needs to be applied on quantization weight scales as well.

Purpose 1

So far, all of quantized Llama-4 models used to skip attention quantization as the model's accuracy would collapse. This was mistakenly interpreted as high sensitivity of the attention layers under quantization. Luckily for quantization, this is not the case and the accuracy collapse was caused by the mismatch in weights and their scales after weight-permutations for rope.

After this PR, we can safely quantize attention layers and maintain competitive accuracy relative to the unquantized baseline.

GSM8k evaluation command:

lm_eval --model vllm --model_args pretrained=MDL,add_bos_token=True,tensor_parallel_size=4,max_model_len=10000 --tasks gsm8k --num_fewshot 5 --batch_size auto

Results:

meta-llama/Llama-4-Scout-17B-16E-Instruct                                                   = 90.67
RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic (w/o attention quant)                   = 90.90
ekurtic/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic_wAttn (w/ attention quant, w/o this PR)  =  5.38
ekurtic/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic_wAttn (w/ attention quant, w/ this PR)   = 91.13

Purpose 2

Enable quantization of meta-llama/Llama-Guard-4-12B models. Not quantizing attention layers in Llama-4 MoEs wasn't a big penalty as those weights usually amount to 5-10% of the total weights compared to experts. However, Llama-Guard-4 models are built by replacing MoE part of Llama-4 with a standard dense layer, resulting in a standard decoder model where attention layers amount to a noticeable fraction of total weights (>10%) and therefore we would prefer to quantize them. This PR enables accurately quantizing all Linear layers (self-attn and MLPs) for Llama-Guard models, just as it is done for all other standard decoder-only architectures.

Note: for simplicity, this PR is a bugfix for INT8 and FP8 per-channel quant schemes. Patch for INT4 will come in a follow-up PR, as it requires slightly more complicated logic due to packed weights.

Signed-off-by: Your Name <you@example.com>
@mergify mergify bot added llama Related to Llama models bug Something isn't working labels Feb 10, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively addresses the bug where attention quantization of Llama-4 models led to accuracy collapse due to incorrect permutation of quantization scales. The changes correctly extend the weight permutation logic to handle scales for INT8 and FP8 quantization schemes, ensuring competitive accuracy relative to the unquantized baseline. This is a crucial fix for enabling quantized attention layers in Llama-4 and Llama-Guard-4 models.

@mergify
Copy link
Copy Markdown

mergify bot commented Feb 10, 2026

Hi @eldarkurtic, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: Your Name <you@example.com>
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 10, 2026
Comment on lines -836 to +848
attn_out = self.config.hidden_size
attn_out = (
self.config.hidden_size
if not is_ct_int8_or_fp8_weight_scale
else w.shape[-1]
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed, because CT transposes scales? Could we just always use the w.shape to decide?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notice that the original implementation was: attn_out = self.config.hidden_size and there is a comment above which says Do not rely on w's shape, as it may be in another layout. I didn't want to break this by doing attn_out = w.shape[-1].

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you think this is safe to ignore, and always use attn_out = w.shape[-1], let me know

Comment on lines +877 to +886
is_ct_int8_or_fp8_weight_scale = False
if modules[-1] == "weight_scale" and isinstance(
self.model.quant_config, ct.CompressedTensorsConfig
):
from compressed_tensors import CompressionFormat

is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to avoid CT specific logic here. Why don't we apply to all weight scales? Is it because of concern over packed weights? I think they wouldn't work anyway

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added CT specific logic here for safety reasons. I don't know how other frameworks are shipping weight-scales

@mgoin mgoin merged commit 11c7ace into vllm-project:main Feb 11, 2026
54 checks passed
eldarkurtic added a commit to eldarkurtic/vllm that referenced this pull request Feb 19, 2026
…cales for rope (int8, fp8) (vllm-project#34243)

Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…cales for rope (int8, fp8) (vllm-project#34243)

Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…cales for rope (int8, fp8) (vllm-project#34243)

Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…cales for rope (int8, fp8) (vllm-project#34243)

Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
## Summary

Cherry-pick upstream bug fixes for RHAIIS 3.3.1 onto `rhai/0.13.0`. All
fixes are from upstream vLLM `main` and address critical bugs affecting
RHAIIS 3.3.0. Other releases (3.2.2, EAx) will be done separately.

**Jira Epic:**
[INFERENG-4743](https://issues.redhat.com/browse/INFERENG-4743)

## Cherry-picked commits (chronological order)

| # | Upstream PR | Jira | Summary |
|---|------------|------|---------|
| 1 | [vllm-project#30550](vllm-project#30550) |
[INFERENG-5106](https://issues.redhat.com/browse/INFERENG-5106) |
Support using chat template as custom score template for reranking
models |
| 2 | [vllm-project#31406](vllm-project#31406) |
[INFERENG-4800](https://issues.redhat.com/browse/INFERENG-4800) | Add
encoder-only/cross attention support to Triton Attention backend |
| 3 | [vllm-project#34243](vllm-project#34243) |
[INFERENG-4746](https://issues.redhat.com/browse/INFERENG-4746) | Fix
Llama-4 attn quantization by correctly permuting scales for rope (int8,
fp8) |
| 4 | [vllm-project#34454](vllm-project#34454) |
[INFERENG-5032](https://issues.redhat.com/browse/INFERENG-5032) | Fix
structured output in multi-turn GPT-OSS (content:null with json_object)
|
| 5 | [vllm-project#34507](vllm-project#34507) |
[INFERENG-5038](https://issues.redhat.com/browse/INFERENG-5038) | Fix
fused MoE int32 overflow in stride*offset for large models |
| 6 | [vllm-project#35085](vllm-project#35085) |
[INFERENG-5028](https://issues.redhat.com/browse/INFERENG-5028) |
Gracefully disable AllReduceFusionPass on GPUs without multicast support
|
| 7 | [vllm-project#35456](vllm-project#35456) |
[INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) |
Replace assert with ValueError for response_format validation
(completions) |
| 8 | [vllm-project#35510](vllm-project#35510) |
[INFERENG-5035](https://issues.redhat.com/browse/INFERENG-5035) | Add
response_format validation to chat completions endpoint |


## Conflict resolutions

<details>
<summary><b>#1 — llama-nemotron-embed / score-template support
(vllm-project#30550)</b>: Clean cherry-pick, no conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#2 — Triton Attention (vllm-project#31406)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly onto `rhai/0.13.0`.
</details>

<details>
<summary><b>#3 — Llama-4 attn quant (vllm-project#34243)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly. 4 intermediate upstream commits touch `llama4.py` but
the fix targets a self-contained block.
</details>

<details>
<summary><b>vllm-project#4 — GPT-OSS multi-turn (vllm-project#34454)</b>: Clean cherry-pick, no
conflicts</summary>

Applied cleanly despite 3 intermediate upstream commits that refactored
imports in `gptoss_reasoning_parser.py`. The fix logic (adding
`eom_token_id` early-exit check in `is_reasoning_end`) was independent
of the import changes.
</details>

<details>
<summary><b>vllm-project#5 — Fused MoE int32 overflow (vllm-project#34507)</b>: Conflicts in 2
files</summary>

**`vllm/model_executor/layers/fused_moe/fused_moe.py`**: ~30
intermediate upstream commits refactored `fused_moe_kernel` with
conditional `naive_block_assignment` logic that doesn't exist in
`rhai/0.13.0`. Resolved by keeping our simpler code and applying only
the int64 cast fix:
- `fused_moe_kernel_gptq_awq`: added `.to(tl.int64)` to `tl.load()`
result
- `fused_moe_kernel`: added `offs_token = offs_token.to(tl.int64)`
before `token_mask`

**`tests/kernels/moe/test_moe.py`**: Upstream test changes depend on
`make_dummy_moe_config()` from intermediate refactors. Resolved by
keeping our existing test code (no test changes).
</details>

<details>
<summary><b>vllm-project#6 — AllReduceFusionPass multicast (vllm-project#35085)</b>: Conflict
due to file rename + API change</summary>

Upstream moved `collective_fusion.py` →
`compilation/passes/fusion/allreduce_rms_fusion.py` and changed the API
from `trtllm_create_ipc_workspace_for_all_reduce_fusion()` to
`create_allreduce_fusion_workspace()`. Resolved by applying the
try/except wrapper around our existing
`trtllm_create_ipc_workspace_for_all_reduce_fusion()` call in
`collective_fusion.py`. The error handling logic (catching RuntimeError
with "multicast" in message, logging warning, returning early) is
identical to upstream.
</details>

<details>
<summary><b>vllm-project#7 — response_format validation for completions
(vllm-project#35456)</b>: Conflict due to file restructuring</summary>

Upstream split `protocol.py` into `completion/protocol.py` and
`chat_completion/protocol.py`. Our branch still has the monolithic
`protocol.py`. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`CompletionRequest` in our `protocol.py`
- Using `ValueError` instead of upstream's `VLLMValidationError` (which
doesn't exist in our branch; `ValueError` is already handled as 400 Bad
Request in `serving_engine.py`)
- Test additions from upstream applied cleanly to
`test_completion_error.py`
</details>

<details>
<summary><b>vllm-project#8 — response_format validation for chat completions
(vllm-project#35510)</b>: Conflict due to file restructuring</summary>

Same file restructuring issue as vllm-project#6. Resolved by:
- Removing the non-existent
`vllm/entrypoints/openai/chat_completion/protocol.py`
- Manually adding `validate_response_format` model_validator to
`ChatCompletionRequest` in our `protocol.py`
- Only accepting the `test_json_schema_response_format_missing_schema`
test from the conflict (discarding ~140 lines of intermediate upstream
tests that reference non-existent paths in our branch)
</details>

## Test plan

- [ ] Verify `llama-nemotron-embed-1b-v2` works correctly with the
backported score-template / bidirectional model support
- [ ] Verify Llama-4 quantized model loads correctly with int8/fp8
attention quantization
- [ ] Verify GPT-OSS multi-turn chat with `json_object` response_format
returns valid content
- [ ] Verify large MoE models (e.g. Qwen3.5-397B) don't crash with int32
overflow
- [ ] Verify MoE model loading on H200 GPUs (without multicast)
gracefully falls back
- [ ] Verify `response_format: {type: "json_schema"}` without
`json_schema` field returns 400 (not 500) for both `/v1/completions` and
`/v1/chat/completions`
- [ ] Verify encoder models (e.g. Whisper) work with Triton attention
backend on ROCm


[INFERENG-4743]:
https://redhat.atlassian.net/browse/INFERENG-4743?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-4800]:
https://redhat.atlassian.net/browse/INFERENG-4800?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-4746]:
https://redhat.atlassian.net/browse/INFERENG-4746?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-5032]:
https://redhat.atlassian.net/browse/INFERENG-5032?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
[INFERENG-5038]:
https://redhat.atlassian.net/browse/INFERENG-5038?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

[INFERENG-5106]:
https://redhat.atlassian.net/browse/INFERENG-5106?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working llama Related to Llama models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants