[Bugfix] Fix FP8 Bias Loading#41424

Merged
Isotr0py merged 4 commits into vllm-project:main from alex-jw-brooks:fix_bias_loads
May 3, 2026
Conversation

@alex-jw-brooks Contributor

Purpose

Fixes the underlying cause of #41284

The issue is that when layers have bias=True, we do the following:

  • Initialize the weight on the meta device and wrap its weight loader; allocate the bias normally (not on the meta device)
  • Params are generally yielded alphabetically when we load them, which means:
    • First, we load the bias normally
    • Then, we load the weight, which requires materializing the meta tensors, i.e. replacing each one with a new on-device tensor created via torch.empty_strided

The materialization currently does this to every parameter, including the bias, even though it should only apply to the weight. This corrupts the already-loaded bias values, which produces NaNs in forward() and ultimately garbage output.

The handling of NaNs is also why things worked for Granite Speech with fp8 in 0.17 but not in 0.20. I think the native forward doesn't handle NaNs in the same way, which is why the values diverge; I'll open a separate PR to discuss.
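The fix described above can be sketched as a minimal standalone illustration (this is not vLLM's actual materialize_layer; the function name, signature, and demo layer are assumptions): only parameters still on the meta device get replaced, so an already-loaded bias survives untouched.

```python
# Hypothetical sketch: materialize only meta-device parameters, leaving
# already-initialized tensors (e.g. a loaded bias) intact.
import torch
import torch.nn as nn


def materialize_layer(layer: nn.Module, device: str = "cpu") -> None:
    """Replace meta-device parameters with empty on-device tensors."""
    for name, param in list(layer.named_parameters(recurse=False)):
        if param.device.type != "meta":
            # Already materialized and possibly already loaded (e.g. the
            # bias); overwriting it here would corrupt its values.
            continue
        materialized = torch.empty_strided(
            param.shape, param.stride(), dtype=param.dtype, device=device
        )
        setattr(layer, name, nn.Parameter(materialized, requires_grad=False))


# Demo: a linear layer whose weight is still on meta, but whose bias has
# already been loaded with real values.
with torch.device("meta"):
    layer = nn.Linear(4, 4)
layer.bias = nn.Parameter(torch.ones(4), requires_grad=False)

materialize_layer(layer)
assert layer.weight.device.type == "cpu"          # weight materialized
assert torch.equal(layer.bias, torch.ones(4))     # bias preserved
```

The key guard is the `param.device.type != "meta"` check: without it, the bias would also be replaced by an uninitialized `torch.empty_strided` tensor, which is exactly the corruption this PR fixes.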

Test Plan

Added an explicit test; you can also verify the fix with a minimal fp8 example using Granite Speech.

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-4.1-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)

model = LLM(
    model=model_id,
    max_model_len=2048, # This may be needed for lower resource devices.
    limit_mm_per_prompt={"audio": 1},
    quantization="fp8",
)

question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    }
}

outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
print(f"Audio Example - Question: {question}")
print(f"Generated text: {outputs[0].outputs[0].text}")

Test Result

On main:

Generated text: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

After fix:

Generated text: the first words i spoke in the original phonograph a little piece of practical poetry mary had a little lamb its fleece was white as snow and everywhere that mary went the lamb was sure to go

CC @DarkLight1337 @robertgshaw2-redhat @lokashrinav

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks alex-jw-brooks requested a review from 22quinn as a code owner April 30, 2026 23:29

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the bug Something isn't working label Apr 30, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the materialize_layer function to ensure that only meta tensors are materialized, preventing the overwriting of already initialized non-meta tensors. A new test case, test_materialize_layer_preserves_non_meta_tensors, has been added to verify this logic. I have no feedback to provide.

@Isotr0py Isotr0py enabled auto-merge (squash) May 2, 2026 04:43
@github-actions github-actions Bot added the ready ONLY add when PR is ready to merge/full CI is needed label May 2, 2026
@Isotr0py Isotr0py merged commit db9a84e into vllm-project:main May 3, 2026
51 checks passed
joa-stdn pushed a commit to joa-stdn/vllm that referenced this pull request May 4, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Joachim Studnia <joachim@mistral.ai>
chaojun-zhang pushed a commit to chaojun-zhang/vllm that referenced this pull request May 6, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request May 7, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>