Skip to content

FP8: Load model on-the-fly in vLLM#3717

Merged
danielhanchen merged 3 commits intounslothai:mainfrom
andrewor14:true-on-the-fly-fp8
Feb 10, 2026
Merged

FP8: Load model on-the-fly in vLLM#3717
danielhanchen merged 3 commits intounslothai:mainfrom
andrewor14:true-on-the-fly-fp8

Conversation

@andrewor14
Copy link
Copy Markdown
Contributor

@andrewor14 andrewor14 commented Dec 11, 2025

Summary: Existing support for load_in_fp8=True performs an offline quantization when loading the initial model. This is no longer necessary as of vllm==0.12.0 (after vllm-project/vllm#23014), where we can quantize the model on-the-fly when we load it:

llm = LLM(
  ...
  hf_overrides={
    "quantization_config_dict_str": json.dumps(torchao_config),
  },
)

Note: Needs unslothai/unsloth-zoo#380

Test Plan:
https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @andrewor14, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the FP8 quantization logic within the Unsloth framework to leverage new capabilities in vLLM (version 0.12.0 and above). Instead of performing an offline quantization step, models can now be quantized on-the-fly when loaded, simplifying the workflow and improving efficiency. The changes involve updating model loading functions to accept a load_in_fp8 parameter and integrating the new on-the-fly quantization mechanism.

Highlights

  • On-the-fly FP8 Quantization: Enables direct FP8 quantization during model loading for vLLM versions 0.12.0 and newer, eliminating the need for prior offline quantization.
  • Streamlined Model Loading: Simplifies the process of loading FP8 quantized models by integrating the quantization step directly into the LLM constructor.
  • Parameter Updates: Introduces load_in_fp8 parameter to from_pretrained methods in llama.py and vision.py and updates loader_utils.py to handle the new quantization flow.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

andrewor14 added a commit to andrewor14/unsloth-zoo that referenced this pull request Dec 11, 2025
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the FP8 model loading to leverage vLLM's on-the-fly quantization for versions 0.12.0 and newer. This is a solid improvement, as it eliminates the need for offline quantization and the creation of a temporary model. The implementation across llama.py, loader.py, loader_utils.py, and vision.py is well-executed, correctly checking the vLLM version to conditionally skip the offline process. My feedback consists of a minor stylistic suggestion to improve the readability of boolean checks, making them more idiomatic to Python.

)

fp8_mode = None
if load_in_fp8 != False:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For boolean checks, it's more idiomatic in Python to use the truthiness of the value directly rather than comparing with False. The load_in_fp8 parameter can be True, False, or a string like 'block'. Both True and non-empty strings are truthy, while False is falsy. Using if load_in_fp8: is more concise and readable, and achieves the same result.

Suggested change
if load_in_fp8 != False:
if load_in_fp8:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah agree with gemini here :)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure I can change it, I just had it this way because I saw that's how Daniel wrote it in a few existing places

lower_model_name = model_name.lower()

assert load_in_fp8 in (True, False, "block")
if load_in_fp8 != False:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to my other comments, this check can be made more Pythonic. Instead of if load_in_fp8 != False:, you can use if load_in_fp8:. This leverages Python's truthiness evaluation and is generally considered better style for readability and conciseness.

Suggested change
if load_in_fp8 != False:
if load_in_fp8:

)

fp8_mode = None
if load_in_fp8 != False:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check if load_in_fp8 != False: can be simplified to if load_in_fp8:. This is the more idiomatic and preferred way to check for truthiness in Python, improving code readability.

Suggested change
if load_in_fp8 != False:
if load_in_fp8:

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 241 to 246
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)
model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
else:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Enforce FP8/4bit mutual exclusion for vLLM >=0.12

When load_in_fp8 is true, get_model_name now returns the original name as soon as vLLM ≥ 0.12.0 (loader_utils.py lines 110-118), so the new_model_name is None branch here is never taken and _get_fp8_mode_and_check_settings no longer runs. With the default load_in_4bit=True, the code now proceeds to fast inference with both load_in_fp8 and use_bitsandbytes=load_in_4bit set, even though _get_fp8_mode_and_check_settings used to reject FP8 together with 4/8/16-bit loads. This yields conflicting quantization paths (fp8 on-the-fly plus bitsandbytes 4bit) and is likely to fail at runtime for users who simply enable load_in_fp8 without also disabling 4bit.

Useful? React with 👍 / 👎.

@danielhanchen
Copy link
Copy Markdown
Contributor

@andrewor14 Oh thanks - would this be backwards compatible?

@andrewor14
Copy link
Copy Markdown
Contributor Author

Oh thanks - would this be backwards compatible?

Yeah, this only affects vllm >= 0.12.0. Behavior is the same as before for older versions. Just tested on 0.12.0 and 0.11.1

andrewor14 and others added 3 commits February 10, 2026 12:42
**Summary:** Existing support for `load_in_fp8=True` performs
an offline quantization when loading the initial model.
This is no longer necessary as of vllm==0.12.0 (after
vllm-project/vllm#23014), where we
can quantize the model on-the-fly when we load it:

```
llm = LLM(
  ...
  hf_overrides={
    "quantization_config_dict_str": json.dumps(torchao_config),
  },
)
```

**Note:** Needs unslothai/unsloth-zoo#380

**Test Plan:**
https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423
The original implementation bypasses the FP8 mapper entirely for
vllm >= 0.12.0, meaning models like Llama-3.2-1B-Instruct and Qwen3-8B
that have pre-quantized FP8-Block/FP8 checkpoints would never use them.

This fixes the priority order:
1. Mapper has a pre-quantized model -> use it (always)
2. Mapper has no match + vllm >= 0.12.0 -> on-the-fly FP8 via torchao
3. Mapper has no match + vllm < 0.12.0 -> offline quantization

Changes:
- loader_utils.py: Move vllm >= 0.12.0 check after mapper lookups
- loader.py: Set load_in_fp8=False when mapper resolves to a
  pre-quantized model to prevent double quantization

Tested on B200 with Llama-3.2-1B-Instruct and Qwen3-8B. Corrected code
produces results matching baseline (pre-quantized path preserved).
@danielhanchen
Copy link
Copy Markdown
Contributor

Rebased on latest main and pushed a fix for the mapper bypass issue.

Problem: The original implementation placed the vllm >= 0.12.0 early return at the top of the FP8 block in __get_model_name, which bypasses the mapper entirely. Models like unsloth/Llama-3.2-1B-Instruct and unsloth/Qwen3-8B that have pre-quantized FP8-Block/FP8 checkpoints would never use them -- they would always get on-the-fly quantization instead.

Fix: Moved the vllm version check to after all mapper lookups, so the priority is:

  1. Mapper has a pre-quantized model --> use it (always)
  2. Mapper has no match + vllm >= 0.12.0 --> on-the-fly FP8 via torchao
  3. Mapper has no match + vllm < 0.12.0 --> offline quantization (existing behavior)

Also added a guard in loader.py to set load_in_fp8 = False when the mapper resolves to a pre-quantized model, preventing double quantization (the pre-quantized model would otherwise also get on-the-fly quantization applied in llama.py/vision.py).

Testing (B200, vllm 0.15.1, 61-step SFT + LoRA):

Llama-3.2-1B-Instruct:

Method Load (s) Train (s) Peak Mem (GB) 1st / Last Loss 1st Grad-Norm
Baseline (main, pre-quant) 47.9 56.2 106.79 1.559 / 1.243 6.484
Original PR (on-the-fly, bypasses mapper) 55.8 61.8 106.79 1.560 / 1.215 1.165
Corrected (mapper first) 50.8 58.9 106.79 1.559 / 1.246 6.483

Qwen3-8B:

Method Load (s) Train (s) Peak Mem (GB) 1st / Last Loss 1st Grad-Norm
Baseline (main, pre-quant) 71.0 102.4 106.81 1.518 / 1.050 1.898
Original PR (on-the-fly, bypasses mapper) 75.7 123.0 107.01 1.519 / 1.039 0.767
Corrected (mapper first) 70.9 102.6 106.81 1.518 / 1.051 1.896

The corrected version matches baseline behavior for mapped models (losses and grad norms are nearly identical). The on-the-fly path is preserved for models not in the mapper.

Remaining note: is_fp8 detection in load_vllm uses "fp8" in model_name.lower(). On the on-the-fly path, the original model name won't contain "fp8", so is_fp8 = False -- this could affect gpu_memory_utilization estimation on smaller GPUs. Worth a follow-up fix.

@danielhanchen danielhanchen merged commit ea9e1fd into unslothai:main Feb 10, 2026
1 check passed
danielhanchen pushed a commit to unslothai/unsloth-zoo that referenced this pull request Feb 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants