
Cherry-pick: Updated fix regression in Mistral-Large-3-675B (#1304) for v0.19.0#1374

Merged
mgawarkiewicz-intel merged 8 commits into vllm-project:releases/v0.19.0 from
skavulya:skavulya/mistral3-rename-scales-0.19.0
Apr 20, 2026

Conversation

@skavulya
Contributor

Renames FP8 blockwise compressed-tensors scales to match HPU ops. Fixes a regression in https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512 introduced by #1220 and #1053.

Copilot AI review requested due to automatic review settings April 17, 2026 23:01
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

skavulya and others added 8 commits April 17, 2026 16:04
Rename FP8 blockwise compressed tensors scales to match HPU ops

Signed-off-by: Kavulya, Soila P <soila.p.kavulya@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Soila Kavulya <soila.p.kavulya@intel.com>
@skavulya skavulya force-pushed the skavulya/mistral3-rename-scales-0.19.0 branch from 1e099b3 to 61ae8c6 on April 17, 2026 23:04
Contributor

Copilot AI left a comment

Pull request overview

Cherry-pick to v0.19.0 that fixes an HPU FP8 regression affecting Mistral-Large-3-675B (and similar) by aligning blockwise FP8 scale tensor naming/handling with HPU ops and ensuring the relevant post-processing paths behave correctly.

Changes:

  • Add a post-load scale “alias/rename” helper and update FP8 dequant/apply paths to prefer *_scale_inv where required by HPU ops (Linear + MoE).
  • Adjust block FP8 weight handling to route through HPU block FP8 linear application and avoid shape/layout mismatches.
  • Add/extend unit tests covering FP8 block-quantized Linear and MoE flows; make MLA kv_b_proj weights contiguous in the MLA HPU path.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File descriptions:

  • vllm_gaudi/ops/hpu_compressed_tensors.py: Introduces scale aliasing to *_scale_inv, updates FP8 block/channel execution paths, and adds MoE quant-config plumbing for renamed scale attributes.
  • vllm_gaudi/attention/oot_mla.py: Ensures dequantized/transposed kv_b_proj weights are contiguous to avoid runtime overhead/issues.
  • tests/unit_tests/ops/test_hpu_compressed_tensors.py: Adds new unit tests for FP8 block-quantized Linear and MoE, focused on the new scale naming + post-processing behavior.

Comment on lines +71 to +72
scale = scale.data if isinstance(scale, torch.nn.Parameter) else scale
layer.register_parameter(hpu_scale_name, torch.nn.Parameter(scale, requires_grad=False))

Copilot AI Apr 17, 2026

_hpu_weight_scale_alias uses .data to extract the underlying tensor before re-wrapping it in a new torch.nn.Parameter. Using .data is discouraged and can bypass autograd safety checks; since these scales are already non-trainable parameters, it’s safer to either move/rename the existing Parameter (preserving its subclass/metadata) or use detach() when creating the new Parameter.

Suggested change
scale = scale.data if isinstance(scale, torch.nn.Parameter) else scale
layer.register_parameter(hpu_scale_name, torch.nn.Parameter(scale, requires_grad=False))
if isinstance(scale, torch.nn.Parameter):
    layer.register_parameter(hpu_scale_name, scale)
else:
    aliased_scale = scale.detach() if isinstance(scale, torch.Tensor) else scale
    layer.register_parameter(
        hpu_scale_name,
        torch.nn.Parameter(aliased_scale, requires_grad=False),
    )

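The suggested pattern can be exercised in isolation. Below is a minimal, self-contained sketch of the idea: re-register an existing Parameter directly under the HPU-style name (preserving its identity and metadata), and only wrap plain tensors via detach() rather than .data. The helper name alias_scale and the attribute names are illustrative assumptions, not identifiers from the PR.

```python
import torch


def alias_scale(layer: torch.nn.Module, src_name: str, hpu_scale_name: str) -> None:
    """Register the scale stored at ``src_name`` under ``hpu_scale_name``.

    Existing Parameters are re-registered as-is (same storage, new name);
    plain tensors are detached and wrapped as non-trainable Parameters.
    """
    scale = getattr(layer, src_name)
    if isinstance(scale, torch.nn.Parameter):
        layer.register_parameter(hpu_scale_name, scale)
    else:
        layer.register_parameter(
            hpu_scale_name,
            torch.nn.Parameter(scale.detach(), requires_grad=False),
        )


layer = torch.nn.Module()
layer.weight_scale = torch.nn.Parameter(torch.ones(4), requires_grad=False)
alias_scale(layer, "weight_scale", "weight_scale_inv")

# The alias points at the very same Parameter object.
assert layer.weight_scale_inv is layer.weight_scale
```

Re-registering the same Parameter object means both names share storage, so later in-place updates to one are visible through the other.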
Comment on lines +453 to +454
weight_fp32 = torch.randn(output_size, input_size, dtype=torch.bfloat16, device="hpu")
weight_fp8 = weight_fp32.to(torch.float8_e4m3fn)

Copilot AI Apr 17, 2026

In this test, weight_fp32 is created with dtype=torch.bfloat16 (and then cast to FP8). Renaming the variable (or adjusting the dtype) would avoid confusion about what precision the tensor actually represents.

Suggested change
weight_fp32 = torch.randn(output_size, input_size, dtype=torch.bfloat16, device="hpu")
weight_fp8 = weight_fp32.to(torch.float8_e4m3fn)
weight_bf16 = torch.randn(output_size, input_size, dtype=torch.bfloat16, device="hpu")
weight_fp8 = weight_bf16.to(torch.float8_e4m3fn)

Comment on lines +467 to +473
# Execute layer with synthetic input
x = torch.randn(1, 4, input_size, dtype=torch.bfloat16, device="hpu")
out = oot_op.scheme.apply_weights(oot_op, x)
assert out.shape == (1, 4, output_size)
assert out.dtype == torch.bfloat16



Copilot AI Apr 17, 2026

This new block-quantized Linear test only asserts shape and dtype. Given the PR changes affect dequantization / scale naming, it would be stronger to also validate numerical correctness (e.g., compare against a reference computed from a dequantized BF16 weight and a BF16 linear/matmul).

Suggested change
# Execute layer with synthetic input
x = torch.randn(1, 4, input_size, dtype=torch.bfloat16, device="hpu")
out = oot_op.scheme.apply_weights(oot_op, x)
assert out.shape == (1, 4, output_size)
assert out.dtype == torch.bfloat16
# Execute layer with deterministic input and validate numerical correctness
# against a BF16 reference computed from the dequantized FP8 weight.
x = torch.ones(1, 4, input_size, dtype=torch.bfloat16, device="hpu")
out = oot_op.scheme.apply_weights(oot_op, x)
assert out.shape == (1, 4, output_size)
assert out.dtype == torch.bfloat16
ref_weight = weight_fp8.to(torch.bfloat16)
ref_out = torch.matmul(x, ref_weight.transpose(0, 1))
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
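As the review suggests, a reference computed from dequantized weights makes the test stronger. The sketch below shows the shape of such a blockwise-dequant reference in plain float32, with illustrative shapes and a block size of 2 for readability; the helper name dequant_blockwise is an assumption, and real FP8 dtypes and HPU device placement are omitted for portability.

```python
import torch


def dequant_blockwise(weight_q: torch.Tensor, scale_inv: torch.Tensor,
                      block: int) -> torch.Tensor:
    """Expand per-block inverse scales elementwise and dequantize.

    weight_q:  (out, in) quantized weight values (cast to float here)
    scale_inv: (out // block, in // block) blockwise inverse scales
    """
    s = scale_inv.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return weight_q.float() * s


torch.manual_seed(0)
w_q = torch.randint(-8, 8, (4, 6)).float()   # stand-in for FP8 integer codes
s_inv = torch.rand(2, 3) + 0.5               # blockwise inverse scales
x = torch.randn(1, 4, 6)

# Reference output from the dequantized weight; a kernel under test would be
# compared against this with torch.allclose and loose FP8-appropriate tolerances.
w_ref = dequant_blockwise(w_q, s_inv, block=2)
out = x @ w_ref.t()
assert out.shape == (1, 4, 4)
assert torch.isfinite(out).all()
```

The loose atol/rtol in the reviewer's suggestion (1e-2) reflects the limited precision of FP8 block quantization; exact equality is not expected.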

out = oot_op.runner.forward_impl(oot_op, hidden_states, router_logits, hidden_states)

assert out.shape == hidden_states.shape
assert out.dtype == torch.bfloat16

Copilot AI Apr 17, 2026

This new block-quantized MoE test currently verifies only output shape and dtype. To better cover the regression being fixed, consider adding a correctness assertion (or at least a stronger sanity check like finite outputs) against a reference implementation for the same weights/scales.

Suggested change
assert out.dtype == torch.bfloat16
assert out.dtype == torch.bfloat16
assert torch.isfinite(out).all(), "block-quantized MoE output should not contain NaN or Inf values"

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
2a69949bdadf0e8942b7a1619b229cb475beef20

@mgawarkiewicz-intel mgawarkiewicz-intel merged commit 2e3ef72 into vllm-project:releases/v0.19.0 Apr 20, 2026
71 checks passed