
[HMA]Move hybrid blksize to update_block_size_for_backend to fix attn supported block size is not 16 issue #37467

Merged
MatthewBonanni merged 29 commits into vllm-project:main from xuechendi:wip_nemotron_h_xpu on Mar 30, 2026

Conversation

@xuechendi (Collaborator) commented Mar 18, 2026

Purpose

--

Issue description:

Testing nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-bf16 with the default block_size = 64.

If FA block_size = 16 w/ TP4 => PASS: the SSM page size divides evenly by the FA page size: 540672 / 8192 = 66

[kv_cache_utils.py:934] layer_spec.block_size=16, layer_spec.page_size_bytes=8192
[kv_cache_utils.py:934] layer_spec.block_size=262144, layer_spec.page_size_bytes=540672

If FA block_size = 64 w/ TP4 => ERROR!!! The SSM page size is not evenly divisible by the FA page size: 540672 / 32768 = 16.5

[kv_cache_utils.py:934] layer_spec.block_size=64, layer_spec.page_size_bytes=32768
[kv_cache_utils.py:934] layer_spec.block_size=262144, layer_spec.page_size_bytes=540672
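
To make the failing check concrete, here is a small sketch. The 512 bytes per FA token under TP4 is an assumption that reproduces the logged page_size_bytes; this is not the actual kv_cache_utils.py code.

```python
# Assumption for illustration: each FA token costs 512 B under TP4,
# since 16 * 512 = 8192 and 64 * 512 = 32768 match the logs above.
BYTES_PER_TOKEN = 512
SSM_PAGE_SIZE = 540672  # mamba/SSM page_size_bytes from the logs

for fa_block_size in (16, 64):
    fa_page = fa_block_size * BYTES_PER_TOKEN
    print(fa_block_size, fa_page, SSM_PAGE_SIZE / fa_page,
          "PASS" if SSM_PAGE_SIZE % fa_page == 0 else "ERROR")
# 16 ->  8192, ratio 66.0 -> PASS
# 64 -> 32768, ratio 16.5 -> ERROR (non-integral ratio)
```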

--

Root Cause:

The default hybrid-model block_size alignment happens before platform.check_and_update_config.
The initialization sequence is:
platform init => cache_config init => model_config init (hybrid model block_size alignment) => __post_init__ => platform.check_and_update_config

Because of this sequence, the current alignment for hybrid models is calculated based only on block_size=16, while the XPU block_size is later updated to 64 in platform.check_and_update_config, which causes the misalignment issue.

--

Solution:

After discussion with the reviewers on this PR, the suggested fix is to update block_size in platform.update_block_size_for_backend for XPU (see the sketch after this list).

This includes:

  • redoing the hybrid-model block_size alignment, and
  • updating block_size per attn_backend.
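
A minimal sketch of the approach, assuming hypothetical names for the config fields and parameters; the real hook is the XPU platform's update_block_size_for_backend together with vLLM's own alignment helpers.

```python
import math

def update_block_size_for_backend(cache_config, attn_backend: str,
                                  mamba_page_size_bytes: int | None,
                                  attn_bytes_per_token: int) -> None:
    # Pick the block size the attention backend actually supports
    # (64 for FLASH_ATTN on XPU, 16 otherwise in this sketch).
    backend_block = 64 if attn_backend == "FLASH_ATTN" else 16
    cache_config.block_size = backend_block

    if mamba_page_size_bytes is None:
        return  # full-attention model: the backend block size is enough

    # Redo the hybrid-model alignment *after* the platform override: grow
    # the attention block size so its page covers the mamba page while
    # staying a multiple of the backend block size. With the TP4 numbers
    # above (512 B/token, 540672 B mamba page) this gives
    # ceil(1056 / 64) * 64 = 1088, matching the log lines in the test plan.
    tokens_needed = math.ceil(mamba_page_size_bytes / attn_bytes_per_token)
    cache_config.block_size = math.ceil(tokens_needed / backend_block) * backend_block
```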

--

Test Plan - re-tested on 03/23 with the latest fix

Validation after the fix:

case 0 - Hybrid model

lm_eval   \
--model vllm   \
--model_args pretrained=/mnt/data/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-bf16,tensor_parallel_size=4,trust_remote_code=True,enable_expert_parallel=True,attention_backend=TRITON_ATTN   \
--tasks gsm8k   \
--num_fewshot 5   \
--batch_size auto
(Worker_TP0_EP0 pid=250806) INFO 03-23 20:58:48 [default_loader.py:384] Loading weights took 5.72 seconds
(Worker_TP0_EP0 pid=250806) INFO 03-23 20:58:49 [gpu_model_runner.py:4596] Model loading took 14.76 GiB memory and 7.550555 seconds
(Worker_TP0_EP0 pid=250806) INFO 03-23 20:58:49 [xpu.py:287] Update hybrid model block size to 1088
(Worker_TP2_EP2 pid=250808) INFO 03-23 20:58:49 [xpu.py:287] Update hybrid model block size to 1088
(Worker_TP1_EP1 pid=250807) INFO 03-23 20:58:50 [xpu.py:287] Update hybrid model block size to 1088
(Worker_TP3_EP3 pid=250809) INFO 03-23 20:58:50 [xpu.py:287] Update hybrid model block size to 1088

case 1 - full attn model

 lm_eval   \
--model vllm   \
--model_args pretrained=Qwen/Qwen3-30B-A3B-Thinking-2507,tensor_parallel_size=4,trust_remote_code=True,enable_expert_parallel=True   \
--tasks gsm8k   \
--num_fewshot 5  \
--batch_size auto
(Worker_TP0_EP0 pid=251994) INFO 03-23 22:12:58 [default_loader.py:384] Loading weights took 16.17 seconds
(Worker_TP1_EP1 pid=251995) INFO 03-23 22:12:59 [xpu.py:306] Update FLASH_ATTN block size to 64
(Worker_TP0_EP0 pid=251994) INFO 03-23 22:12:59 [gpu_model_runner.py:4596] Model loading took 14.3 GiB memory and 133.598931 seconds
(Worker_TP0_EP0 pid=251994) INFO 03-23 22:12:59 [xpu.py:306] Update FLASH_ATTN block size to 64
(Worker_TP2_EP2 pid=251996) INFO 03-23 22:12:59 [xpu.py:306] Update FLASH_ATTN block size to 64
(Worker_TP3_EP3 pid=251997) INFO 03-23 22:12:59 [xpu.py:306] Update FLASH_ATTN block size to 64

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request updates the KV cache page size unification logic to handle models with non-integral page size ratios between layers. It replaces the use of max() with math.lcm() to determine the unified page size, which is a more robust approach for this scenario. The changes also correctly propagate this scaling to page_size_padded if it is present in the cache specification. While this change is correct, I've identified a potential issue where the calculated LCM could become excessively large, leading to high memory consumption. I've added a high-severity comment with a suggestion to add a safeguard. An unrelated change for XPU support is also included in vllm/model_executor/layers/fused_moe/layer.py.
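
A minimal sketch of what the bot is suggesting, with illustrative names rather than the actual kv_cache_utils.py API: unify per-layer page sizes with math.lcm instead of max(), and cap pathological LCM blow-up.

```python
import math

MAX_UNIFIED_PAGE_SIZE = 16 * 1024 * 1024  # assumed safeguard against huge LCMs

def unified_page_size(page_sizes: list[int]) -> int:
    unified = math.lcm(*page_sizes)
    if unified > MAX_UNIFIED_PAGE_SIZE:
        raise ValueError(
            f"unified page size {unified} exceeds {MAX_UNIFIED_PAGE_SIZE}; "
            "consider a different block_size instead")
    return unified

print(unified_page_size([8192, 540672]))   # 540672  (integral 66x ratio)
print(unified_page_size([32768, 540672]))  # 1081344 (the 16.5x case doubles the SSM page)
```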

(Outdated review comment thread on vllm/v1/core/kv_cache_utils.py)
mergify (Bot) commented Mar 18, 2026

Hi @xuechendi, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@xuechendi xuechendi changed the title [HMA][KV_CACHE_UTILS] propose lcm for hybrid page_size to handle non-integral ratio - test by Nemotron [HMA][KV_CACHE_UTILS] Fix for hybrid page_size to handle non-integral ratio Mar 18, 2026
@xuechendi xuechendi force-pushed the wip_nemotron_h_xpu branch from c954497 to 32dcef7 Compare March 18, 2026 19:06

@xuechendi xuechendi force-pushed the wip_nemotron_h_xpu branch 2 times, most recently from c523883 to b724691 Compare March 18, 2026 21:00
@xuechendi xuechendi changed the title [HMA][KV_CACHE_UTILS] Fix for hybrid page_size to handle non-integral ratio [HMA]Fix corner case when hybrid page_size can not be evenly divided issue Mar 18, 2026

yma11 (Contributor) commented Mar 19, 2026

There is another fix at #37429, and I think that one considers more cases?

@NickLucche (Collaborator) left a comment

2 things off the top of my head

  • lm_eval scores are lower than they should be for this model (iirc we're at 80+ for nemotronH @ZhanqiuHu)
  • what's the setup that is leading to this assert? I don't recall seeing this issue before.
    In particular

has two page_size for linear_attn and full_attn

this is only true when kernel_block_size is required, and then again in that case the page_size is divided by the logical/physical ratio (duplicating num_blocks accordingly).

Let's look deeper into the cause of this.
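
A rough sketch of the kernel_block_size mechanism described above, with illustrative names and the same 512 B/token assumption as earlier; this is not vLLM's actual code:

```python
# Each logical KV block is split into `ratio` physical blocks when the
# kernel only supports a smaller block size; the page size shrinks by
# `ratio` and the block count is duplicated accordingly.
logical_block_size = 1088          # tokens per logical block (hybrid alignment)
kernel_block_size = 64             # block size the attention kernel supports
logical_page_size = 1088 * 512     # bytes, assuming 512 B/token
num_logical_blocks = 100

ratio = logical_block_size // kernel_block_size   # 17 physical per logical
physical_page_size = logical_page_size // ratio   # 557056 // 17 = 32768 B
num_physical_blocks = num_logical_blocks * ratio  # 1700
print(ratio, physical_page_size, num_physical_blocks)
```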

@xuechendi (Collaborator, Author) commented Mar 19, 2026

@NickLucche

this is only true when kernel_block_size is required, and then again in that case the page_size is divided by the logical/physical ratio (duplicating num_blocks accordingly).

Oh, I am testing with an Intel GPU, which uses block_size = 64.

I just tested with block_size = 16, and we won't need this fix in that case. I'll update the title and purpose.


For accuracy, I am following the config here: https://github.com/vllm-project/vllm/blob/main/.buildkite/lm-eval-harness/configs/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.yaml

If using the full set (no limit) and without multiturn, strict-match acc gets to 0.837. Wondering if you're testing with lm-eval-harness/configs/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.yaml or using a different gsm8k setting?

@xuechendi xuechendi changed the title [HMA]Fix corner case when hybrid page_size can not be evenly divided issue [HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64) Mar 19, 2026
ZhanqiuHu (Contributor) commented Mar 19, 2026

If using the full set (no limit) and without multiturn, strict-match acc gets to 0.837. Wondering if you're testing with lm-eval-harness/configs/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.yaml or using a different gsm8k setting?

Yes, without apply_chat_template and fewshot_as_multiturn, the expected acc is ~0.84.

ZhanqiuHu (Contributor) commented Mar 19, 2026

@xuechendi btw was this function called: verify_and_update_config?

@xuechendi xuechendi changed the title [HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64) [HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64,tp=4) Mar 19, 2026
@xuechendi xuechendi force-pushed the wip_nemotron_h_xpu branch 2 times, most recently from 9192ad6 to d8c6c7e Compare March 19, 2026 22:34
@xuechendi (Collaborator, Author) commented:

We'd better rename the title before merging.

Any suggestions?

@jikunshang (Collaborator) commented:

Something like "Refactor block size selection/update and fix hybrid model alignment".

@xuechendi xuechendi changed the title [HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64,tp=4) [HMA]Move hybrid blksize to update_block_size_for_backend to fix attn supported block size is not 16 issue Mar 28, 2026
yma11 (Contributor) commented Mar 30, 2026

@xuechendi I found that on platform XPU the final block size is divisible by 64 when using FLASH_ATTN but not when using TRITON_ATTN, which causes incorrect GDN attention output. How can we ensure the final block size is divisible by 64 no matter which attention backend is used?

INFO 03-30 07:05:40 [config.py:228] Setting attention block size to 528 tokens to ensure that attention page size is >= mamba page size.
INFO 03-30 07:05:40 [config.py:259] Padding mamba page size by 0.76% to ensure that mamba page size and attention page size are exactly equal.
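
One hedged way to address this concern (a sketch only, not the merged behavior) would be to round the chosen attention block size up to the next multiple of 64 so the XPU GDN kernel constraint holds for any backend; round_up_to_multiple is a hypothetical helper:

```python
def round_up_to_multiple(block_size: int, multiple: int = 64) -> int:
    return -(-block_size // multiple) * multiple  # ceiling division

# 528 from the log above is not 64-aligned; the next aligned value is 576.
assert round_up_to_multiple(528) == 576
```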

@xuechendi (Collaborator, Author) commented:

@xuechendi I found that on platform XPU the final block size is divisible by 64 when using FLASH_ATTN but not when using TRITON_ATTN, which causes incorrect GDN attention output. How can we ensure the final block size is divisible by 64 no matter which attention backend is used?

INFO 03-30 07:05:40 [config.py:228] Setting attention block size to 528 tokens to ensure that attention page size is >= mamba page size.
INFO 03-30 07:05:40 [config.py:259] Padding mamba page size by 0.76% to ensure that mamba page size and attention page size are exactly equal.

I did the same test for Triton, and since Triton should not limit block_size to 64, the logic here sets it to 16. Not sure why you're seeing misalignment; did you set triton_attn block_size=64 somewhere in your code?

yma11 (Contributor) commented Mar 30, 2026

TRITON_ATTN has no limitation requiring the block size to be divisible by 64, but GDN attention computes on the KV cache, so it still needs the block size to be divisible by 64. This is an XPU GDN attention limitation for the time being.

@xuechendi (Collaborator, Author) commented:

Can you share more details? Is that specific to the current xpu_kv_cache impl? If it only affects GDN, maybe we can fix it in the Qwen3.5 PR, since I am not able to test it right now.

@MatthewBonanni (Collaborator) commented:

Since #33657 hasn't landed yet, I think this PR is fine as-is for now. If GDN truly requires block size 64 for all backends on XPU, that change can be made as part of #33657.

@MatthewBonanni MatthewBonanni requested a review from ZJY0516 as a code owner March 30, 2026 14:29
@MatthewBonanni MatthewBonanni merged commit 3b1dbaa into vllm-project:main Mar 30, 2026
73 of 76 checks passed
neweyes pushed a commit to neweyes/vllm that referenced this pull request Mar 31, 2026
…issue (blk_size=64,tp=4) (vllm-project#37467)

puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
…issue (blk_size=64,tp=4) (vllm-project#37467)

ccrhx4 pushed a commit to ccrhx4/huanxing.vllm-fork that referenced this pull request Apr 9, 2026
…issue (blk_size=64,tp=4) (vllm-project#37467)

mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
…issue (blk_size=64,tp=4) (vllm-project#37467)

iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request Apr 10, 2026
…xtral, MoE and Granite regressions (#1311)

## Summary
This PR fixes a set of regressions introduced by recent upstream changes
and observed in vLLM-Gaudi hourly validation.

The branch now includes:
- Pixtral HPUAttention projection path fix,
- MoE dispatch and method override alignment updates for fused MoE and
compressed tensors,
- unit test updates to match the new MoE runner API usage,
- fix hybrid model page size alignment for Granite 4.0-H.

## Related upstream PRs that introduced the regressions
- vllm-project/vllm#37234
- vllm-project/vllm#35153
- vllm-project/vllm#36963
- vllm-project/vllm#38960
- vllm-project/vllm#35326
- vllm-project/vllm#37467

---------

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>

Labels

intel-gpu (Related to Intel GPU), ready (ONLY add when PR is ready to merge/full CI is needed), v1
