Fix Llama4 shape mismatch for 32k+ context window (#842)#855
Conversation
Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: vllm-project#680 and vllm-project#684 --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com>
There was a problem hiding this comment.
Pull request overview
This PR fixes a shape mismatch issue in the Llama4 model when using context windows larger than 32k tokens. The problem occurs because temperature adjustment enabled for long contexts modifies the input tensor shape from 2D to 3D, but the output reshaping logic didn't account for this variation based on data parallelism configuration.
Changes:
- Modified the output reshaping logic in
UnquantizedFusedMoEMethod.forward_oot()to handle both 2D and 3D input tensors correctly based on thedp_sizevalue
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if layer.dp_size > 1: | ||
| return output.view(*(output.size(0), *input_shape[1:])) | ||
| else: |
There was a problem hiding this comment.
The conditional logic based on dp_size > 1 appears to be a workaround for handling different input shapes rather than addressing the root cause. Consider explicitly checking the input tensor dimensionality (len(input_shape)) to make the intent clearer and more maintainable. This would better document why different reshaping strategies are needed and make the code less fragile if dp_size semantics change.
| if layer.dp_size > 1: | |
| return output.view(*(output.size(0), *input_shape[1:])) | |
| else: | |
| if len(input_shape) == 2: | |
| # Handle 2D inputs where the leading dimension may have been | |
| # modified (e.g. by data parallel dispatch); keep the trailing | |
| # dimension(s) from the original shape and infer the leading one | |
| # from the actual output tensor. | |
| return output.view(output.size(0), *input_shape[1:]) | |
| else: | |
| # For higher-rank inputs, restore the original shape directly. |
✅ CI PassedAll checks passed successfully against the following vllm commit: |
…llm-project#855) Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: vllm-project#680 and vllm-project#684 Cherry-picked from `releases/v0.13.0` --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com>
…llm-project#855) Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: vllm-project#680 and vllm-project#684 Cherry-picked from `releases/v0.13.0` --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com>
1. #805 2. #837 3. #855 4. #862 --------- Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com> Signed-off-by: linoy buchnik <lbuchnik@habana.ai> Signed-off-by: Iryna Boiko <iboiko@habana.ai> Signed-off-by: Artur Fierka <artur.fierka@intel.com> Co-authored-by: Linoy Buchnik <linoybu@gmail.com> Co-authored-by: Iryna Boiko <iboiko@habana.ai> Co-authored-by: Artur Fierka <artur.fierka@intel.com>
…#855 (#881) Cherry pick missing fixes: chunked attention fixes from #821 llama4 32k+ context window #855 --------- Signed-off-by: Luca Calabria <luca.calabria@intel.com> Signed-off-by: Jakub Byczkowski <jbyczkowski@habana.ai> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com> Signed-off-by: linoy buchnik <lbuchnik@habana.ai> Signed-off-by: Iryna Boiko <iboiko@habana.ai> Signed-off-by: Artur Fierka <artur.fierka@intel.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Jakub Byczkowski <jbyczkowski@habana.ai> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Radosław Smyrek <radoslawx.smyrek@intel.com> Co-authored-by: Linoy Buchnik <linoybu@gmail.com> Co-authored-by: Iryna Boiko <iboiko@habana.ai> Co-authored-by: Artur Fierka <artur.fierka@intel.com>
…llm-project#855) Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: vllm-project#680 and vllm-project#684 Cherry-picked from `releases/v0.13.0` --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com> Signed-off-by: Wang, Zheng W <zheng.w.wang@intel.com>
1. vllm-project#805 2. vllm-project#837 3. vllm-project#855 4. vllm-project#862 --------- Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com> Signed-off-by: linoy buchnik <lbuchnik@habana.ai> Signed-off-by: Iryna Boiko <iboiko@habana.ai> Signed-off-by: Artur Fierka <artur.fierka@intel.com> Co-authored-by: Linoy Buchnik <linoybu@gmail.com> Co-authored-by: Iryna Boiko <iboiko@habana.ai> Co-authored-by: Artur Fierka <artur.fierka@intel.com> Signed-off-by: slokesha <slokeshappa@habana.ai>
…ndow fix from vllm-project#855 (vllm-project#881) Cherry pick missing fixes: chunked attention fixes from vllm-project#821 llama4 32k+ context window vllm-project#855 --------- Signed-off-by: Luca Calabria <luca.calabria@intel.com> Signed-off-by: Jakub Byczkowski <jbyczkowski@habana.ai> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: Radoslaw Smyrek <radoslawx.smyrek@intel.com> Signed-off-by: linoy buchnik <lbuchnik@habana.ai> Signed-off-by: Iryna Boiko <iboiko@habana.ai> Signed-off-by: Artur Fierka <artur.fierka@intel.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Jakub Byczkowski <jbyczkowski@habana.ai> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Radosław Smyrek <radoslawx.smyrek@intel.com> Co-authored-by: Linoy Buchnik <linoybu@gmail.com> Co-authored-by: Iryna Boiko <iboiko@habana.ai> Co-authored-by: Artur Fierka <artur.fierka@intel.com> Signed-off-by: slokesha <slokeshappa@habana.ai>
…llm-project#855) Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: vllm-project#680 and vllm-project#684 Cherry-picked from `releases/v0.13.0` --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com> Signed-off-by: slokesha <slokeshappa@habana.ai>
Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: #680 and #684 Cherry-picked from `releases/v0.13.0` --------- Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Llama4 for
max_model_len > 32kenable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensorqshape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing toUnqnatizedFusedMoEMetod -> forward: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3Doutput.viewbased on 2D output tensor.Found that following PR introduced the bug: #680 and #684
Cherry-picked from
releases/v0.13.0