[HPU] Fix FP8 block-to-channel conversion breaking MLA weight loading #1220

Merged
adobrzyn merged 4 commits into main from afierka/fix-fp8-mla-and-torchaudio-import
Mar 25, 2026

Conversation

@afierka-intel
Collaborator

@afierka-intel commented Mar 23, 2026

Summary

Fix FP8 block-to-channel conversion that leaves stale weight_block_size, crashing MLA weight loading for DeepSeek-R1 and similar models on HPU.

Problem

When VLLM_HPU_FORCE_CHANNEL_FP8=True (default), fp8_block_linear_postprocess_weights in ops.py converts block-quantized FP8 weights to channel-wise FP8. After conversion, weight_scale_inv becomes 1D [N_out] (per-channel), but weight_block_size is not cleared (remains [128, 128]).

Any downstream code that uses weight_block_size as a group_shape for dequantization (e.g. MLAAttention.process_weights_after_loading → scaled_dequantize) fails:

ValueError: 1D scale with shape torch.Size([4096]) cannot be broadcast
to x with shape torch.Size([4096, 512]), group_shape=(128, 128)
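The failure mode can be reproduced with a small numpy stand-in for the torch tensors involved; `scaled_dequantize` below is a hypothetical sketch of the upstream shape check, not vLLM's actual implementation:

```python
import numpy as np

def scaled_dequantize(x, scale, group_shape):
    # Hypothetical sketch: a block group_shape expects a 2D scale of
    # shape [N/gb, K/gk]; with group_shape cleared (None), the 1D scale
    # is treated as per-output-channel.
    if group_shape is not None:
        expected = (x.shape[0] // group_shape[0], x.shape[1] // group_shape[1])
        if scale.shape != expected:
            raise ValueError(
                f"scale with shape {scale.shape} cannot be broadcast to x "
                f"with shape {x.shape}, group_shape={group_shape}")
        # Expand the per-block scale to the full weight shape.
        full = np.kron(scale, np.ones(group_shape, scale.dtype))
        return x * full[: x.shape[0], : x.shape[1]]
    return x * scale[:, None]

weight = np.ones((4096, 512), dtype=np.float32)
channel_scale = np.full(4096, 0.5, dtype=np.float32)  # 1D per-channel scale

# Stale group_shape=(128, 128) with the 1D scale raises the ValueError above;
# with weight_block_size cleared, the per-channel path broadcasts fine:
dequant = scaled_dequantize(weight, channel_scale, None)
```

Passing `(128, 128)` as `group_shape` with the same 1D scale raises the ValueError shown above, which is exactly the stale-`weight_block_size` symptom.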

Fix

Two changes:

  1. Root cause (vllm_gaudi/extension/ops.py): Clear layer.weight_block_size = None after block→channel conversion in fp8_block_linear_postprocess_weights. This fixes the issue for all layers, not just MLA.

  2. MLA-specific handling (vllm_gaudi/attention/oot_mla.py): HPUMLAAttention.process_weights_after_loading now handles FP8 kv_b_proj dequantization directly for both:

    • Channel-wise (1D scale) — broadcast multiply with act_dtype
    • Block (2D scale) — dequant_block_fp8_weight_naive with explicit dtype=act_dtype

    Non-FP8 models are unaffected (falls through to upstream logic).
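The two scale formats handled by the override can be sketched as follows. This is a numpy stand-in for illustration only; `dequant_kv_b_proj` and its signature are hypothetical, and the block branch stands in for `dequant_block_fp8_weight_naive(..., dtype=act_dtype)`:

```python
import numpy as np

def dequant_kv_b_proj(weight, weight_scale_inv, weight_block_size, act_dtype):
    # Hypothetical sketch of the branch logic described above.
    if weight_scale_inv.ndim == 1:
        # Channel-wise (1D scale): broadcast multiply in act_dtype.
        ws = weight_scale_inv.reshape(-1, 1).astype(act_dtype)
        return (weight.astype(act_dtype) * ws).T
    # Block (2D scale): expand [N/bn, K/bk] scales to the full weight
    # shape, then multiply in act_dtype (stand-in for
    # dequant_block_fp8_weight_naive with dtype=act_dtype).
    bn, bk = weight_block_size
    full = np.kron(weight_scale_inv.astype(act_dtype),
                   np.ones((bn, bk), act_dtype))
    full = full[: weight.shape[0], : weight.shape[1]]
    return (weight.astype(act_dtype) * full).T

w = np.ones((256, 128), dtype=np.float32)
per_channel = np.full(256, 2.0, dtype=np.float32)   # 1D -> channel path
per_block = np.full((2, 1), 2.0, dtype=np.float32)  # 2D -> block path (128x128)
```

Both paths produce the transposed dequantized weight in `act_dtype`; with equivalent scales they yield identical results, which is why non-FP8 models can safely fall through to the upstream logic.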

Testing

Tested on g2.l pod (8× Gaudi2 HL-225, 98304 MiB each) with:

  • vllm 0.18.1rc1 (commit aec2dc6c0)
  • vllm-gaudi main branch
  • Model: DeepSeek-R1 (FP8 block quant, weight_block_size=[128,128])
  • TP=8, gpu_memory_utilization=0.9

Results

  • ✅ step-2-measure-scales.py with mlperf dataset — completed successfully
  • ✅ step-2-measure-scales.py with NeelNanda/pile-10k (512 samples) — completed without OOM
  • ✅ Model loads and runs inference correctly

Affected models

Any model using MLA attention + FP8 block quantization on HPU:

  • DeepSeek-R1
  • DeepSeek-V3 (FP8 variants)
  • Other DeepSeek V2/V3 architecture models with FP8 checkpoints

@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@afierka-intel force-pushed the afierka/fix-fp8-mla-and-torchaudio-import branch from 8c66943 to e9a9525 on March 23, 2026 at 14:16
@afierka-intel changed the title from "[HPU] Fix FP8 MLA kv_b_proj dequantization and torchaudio import crash" to "[HPU] Fix FP8 MLA kv_b_proj dequantization for block-quantized models" on Mar 23, 2026
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

When VLLM_HPU_FORCE_CHANNEL_FP8=True (default), block-quantized FP8
weights are converted to channel-wise FP8 by hpu_fp8.py. This makes
weight_scale_inv 1D [N_out], but weight_block_size is not cleared.
The upstream MLAAttention.process_weights_after_loading then calls
scaled_dequantize with group_shape=weight_block_size, causing:
  ValueError: 1D scale with shape [4096] cannot be broadcast to x
  with shape [4096, 512], group_shape=(128, 128)

Fix: HPUMLAAttention now handles FP8 kv_b_proj dequantization
directly for both channel-wise (1D) and block (2D) scale formats.

Tested with DeepSeek-R1 FP8 on g2.l (8x Gaudi2, TP=8) using
calibration step-2-measure-scales.py with both mlperf and
NeelNanda/pile-10k datasets.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
@afierka-intel force-pushed the afierka/fix-fp8-mla-and-torchaudio-import branch from e9a9525 to c0eedba on March 23, 2026 at 14:18
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

Contributor

Copilot AI left a comment

Pull request overview

This PR addresses two HPU startup/runtime blockers for MLA + FP8 block-quantized models (e.g., DeepSeek-R1): a kv_b_proj dequantization failure when block-FP8 is converted to channel-wise FP8, and an eager torchaudio import crash during model registry initialization.

Changes:

  • Override MLA weight post-processing on HPU to directly dequantize kv_b_proj correctly for both channel-wise and block FP8 scales.
  • Add a safe model-registration helper to avoid vLLM startup crashes when optional model dependencies (e.g., torchaudio) are unavailable.

Reviewed changes

Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.

File Description
vllm_gaudi/models/__init__.py Adds _safe_register() to skip (with warning) model registrations that fail due to missing optional deps.
vllm_gaudi/attention/oot_mla.py Adds HPU-specific kv_b_proj dequantization logic to avoid scale shape mismatch after block→channel FP8 conversion.

Comment thread: vllm_gaudi/attention/oot_mla.py
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
9ace378a63cac49129829ac30e9645bb8af4d2d5

@adobrzyn
Collaborator

Finding 1 🟡 Medium · oot_mla.py — general approach

The root cause is in fp8_block_linear_postprocess_weights, not in MLA.

The real bug is that fp8_block_linear_postprocess_weights (called by Fp8LinearMethod.process_weights_after_loading) converts block-quantized FP8 to channel-wise FP8 and updates weight_scale_inv to 1D, but does not clear weight_block_size. This stale weight_block_size=[128,128] then causes scaled_dequantize to fail with a shape mismatch.

Fixing this in fp8_block_linear_postprocess_weights (e.g., layer.weight_block_size = None after channel conversion) would fix the issue for all layers that call get_and_maybe_dequant_weights, not just MLA's kv_b_proj. The current fix is correct and self-contained, but any future layer that hits the same path will need its own workaround.

Suggestion: Consider also adding a one-liner to fp8_block_linear_postprocess_weights (when force_channel_fp8=True):

layer.weight_block_size = None  # scale is now per-channel, not per-block

This could be done in a follow-up PR or alongside this one. At minimum, add a # TODO comment in the HPU MLA override explaining why the root cause fix isn't applied here.



@adobrzyn
Collaborator

Finding 2 🟡 Medium · oot_mla.py:L178

Block FP8 path uses hardcoded dtype=torch.bfloat16 instead of act_dtype.

In the channel-wise FP8 branch, the dequantization correctly uses act_dtype:

ws = weight_scale_inv.view(-1, 1).to(act_dtype)
kv_b_proj_weight = (weight.to(act_dtype) * ws).T

But in the block FP8 branch, dequant_block_fp8_weight_naive is called without specifying dtype, defaulting to torch.bfloat16:

kv_b_proj_weight = dequant_block_fp8_weight_naive(
    weight,
    weight_scale_inv,
    kv_b_proj.weight_block_size,
    original_M=orig_M,
    original_N=orig_N,
    do_unpad=(orig_M is not None),
).T

If act_dtype is ever not bfloat16 (e.g., float16), the block path would produce bfloat16 tensors while the rest of the model expects act_dtype.

Suggestion: Pass dtype=act_dtype explicitly:

kv_b_proj_weight = dequant_block_fp8_weight_naive(
    weight,
    weight_scale_inv,
    kv_b_proj.weight_block_size,
    dtype=act_dtype,
    original_M=orig_M,
    original_N=orig_N,
    do_unpad=(orig_M is not None),
).T


- Pass dtype=act_dtype to dequant_block_fp8_weight_naive in the block
  FP8 path to avoid hardcoded bfloat16 assumption.
- Clear layer.weight_block_size=None in fp8_block_linear_postprocess_weights
  after block-to-channel conversion, fixing the root cause for all layers.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Signed-off-by: Artur Fierka <artur.fierka@intel.com>
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@afierka-intel changed the title from "[HPU] Fix FP8 MLA kv_b_proj dequantization for block-quantized models" to "[HPU] Fix FP8 block-to-channel conversion breaking MLA weight loading" on Mar 24, 2026
@afierka-intel
Collaborator Author

Addressed review findings from @adobrzyn and Copilot reviewer in commit cc70fef:

  1. Root cause fix (ops.py): Added layer.weight_block_size = None in fp8_block_linear_postprocess_weights after block→channel conversion. This fixes the stale weight_block_size issue for all layers, not just MLA. (Finding 1)

  2. dtype=act_dtype (oot_mla.py): Passed explicit dtype=act_dtype to dequant_block_fp8_weight_naive in the block FP8 path, instead of relying on the hardcoded bfloat16 default. (Finding 2 / Copilot review)

  3. Branch name mismatch (Finding 3): intentional — torchaudio fix was split out to a separate PR.

Also updated PR title and description to reflect both changes.

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
9ace378a63cac49129829ac30e9645bb8af4d2d5

@adobrzyn merged commit 6afc171 into main on Mar 25, 2026
75 checks passed
yeonsily pushed a commit to yeonsily/vllm-gaudi that referenced this pull request Mar 27, 2026
…vllm-project#1220)

Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
adobrzyn pushed a commit that referenced this pull request Mar 31, 2026
…#1220)

Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
afierka-intel pushed a commit that referenced this pull request Apr 14, 2026
Rename FP8 blockwise compressed tensors scales to match HPU ops; fixes regression in
https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512 due
to #1220 and #1053

---------

Signed-off-by: Kavulya, Soila P <soila.p.kavulya@intel.com>
Copilot AI pushed a commit that referenced this pull request Apr 14, 2026
Rename FP8 blockwise compressed tensors scales to match HPU ops; fixes regression in
https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512 due
to #1220 and #1053

---------

Signed-off-by: Kavulya, Soila P <soila.p.kavulya@intel.com>
Co-authored-by: michalkuligowski <23379006+michalkuligowski@users.noreply.github.com>
mgawarkiewicz-intel pushed a commit that referenced this pull request Apr 20, 2026
…or v0.19.0 (#1374)

Renames FP8 blockwise compressed tensors scales to match HPU ops; fixes regression in
https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512 due
to #1220 and #1053

---------

Signed-off-by: Kavulya, Soila P <soila.p.kavulya@intel.com>
Signed-off-by: Soila Kavulya <soila.p.kavulya@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>