FP8: Load model on-the-fly in vLLM #380

Merged
danielhanchen merged 1 commit into unslothai:main from andrewor14:true-on-the-fly-fp8 on Feb 10, 2026

@andrewor14
Contributor

@gemini-code-assist
Contributor

Summary of Changes

Hello @andrewor14, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the model loading capabilities within vLLM by introducing on-the-fly FP8 quantization. This feature, available for vLLM versions 0.12.0 and above, leverages torchao to configure and apply FP8 precision, which can lead to more efficient memory usage and faster inference. The changes also include a necessary adjustment to the context manager used during state dictionary retrieval to ensure compatibility and stability with FP8 operations.

Highlights

  • FP8 Quantization Support: Introduced support for on-the-fly FP8 quantization when loading models into vLLM, specifically for vLLM versions 0.12.0 and newer.
  • Torchao Integration: Integrated torchao for configuring the FP8 quantization process, allowing users to specify the desired FP8 mode ('row' or 'block').
  • Context Manager Adjustment: Modified the get_vllm_state_dict function to use torch.no_grad() when FP8 quantization is active, addressing a potential PyTorch bug related to tensor subclasses.
  • New Helper Function: Added a new private helper function _get_torchao_fp8_config to streamline the creation of torchao FP8 quantization configurations.

andrewor14 added a commit to andrewor14/unsloth that referenced this pull request Dec 11, 2025
**Summary:** Existing support for `load_in_fp8=True` performs
an offline quantization when loading the initial model.
This is no longer necessary as of vllm==0.12.0 (after
vllm-project/vllm#23014), where we
can quantize the model on-the-fly when we load it:

```
llm = LLM(
  ...
  hf_overrides={
    "quantization_config_dict_str": json.dumps(torchao_config),
  },
)
```

**Note:** Needs unslothai/unsloth-zoo#380

**Test Plan:**
https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423
Contributor

@gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for on-the-fly FP8 quantization in vLLM. It adds a new fp8_mode parameter to load_vllm to configure the vLLM engine for torchao-based quantization. A helper function _get_torchao_fp8_config is added to generate the necessary configuration. The get_vllm_state_dict function is also updated to handle FP8 models correctly.

My review focuses on improving the clarity, maintainability, and correctness of the new implementation. I've pointed out a potential issue where the fp8_mode is silently ignored for older vLLM versions, which could be misleading. I've also suggested improvements to the documentation and handling of hardcoded values. Finally, I've noted that the new FP8 functionality does not appear to be covered by tests in this file, and I recommend adding them.

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +2103 to +2110
if fp8_mode is not None and Version(vllm_version) >= Version("0.12.0"):
    from torchao.core.config import config_to_dict
    torchao_config = _get_torchao_fp8_config(fp8_mode)
    hf_overrides = {
        "quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)),
    }
    engine_args["quantization"] = "torchao"
    engine_args["hf_overrides"] = hf_overrides

high

When fp8_mode is specified but the vLLM version is older than 0.12.0, the condition is false and the block is skipped. This means the fp8_mode is silently ignored, and the model will not be quantized with FP8 as requested. This can be misleading for the user.

Consider adding an else block to handle this case, for example by raising a NotImplementedError or at least logging a prominent warning that on-the-fly FP8 quantization is not supported and is being skipped.

Comment thread unsloth_zoo/vllm_utils.py
def _get_torchao_fp8_config(fp8_mode: str):
    """
    Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig`
    to be used for `load_in_fp8=True`.

medium

The docstring mentions load_in_fp8=True, but this function is called when fp8_mode is set in load_vllm. The parameter load_in_fp8 exists in get_vllm_state_dict. This is a bit confusing. To improve clarity, consider updating the docstring to refer to fp8_mode.

Suggested change
- to be used for `load_in_fp8=True`.
+ to be used when `fp8_mode` is set.

Comment thread unsloth_zoo/vllm_utils.py
if fp8_mode == "row":
    granularity = PerRow()
elif fp8_mode == "block":
    granularity = (PerBlock([1, 128]), PerBlock([128, 128]))

medium

The block granularity for FP8 quantization is hardcoded. While this might be the intended configuration for now, it reduces flexibility. Consider defining this as a constant at the module level. This would make it easier to find and change if needed in the future, improving maintainability.
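
One way the module-level constant could look, as a rough sketch. `FP8_BLOCK_SHAPES`, `SUPPORTED_FP8_MODES`, and `fp8_granularity_spec` are hypothetical names. The helper returns a plain description instead of torchao objects so it runs without torchao, and it assumes the first block shape applies to activations and the second to weights.

```python
from typing import Dict, Tuple

# Block shapes copied from the diff; the first entry is assumed to apply to
# activations and the second to weights.
FP8_BLOCK_SHAPES: Tuple[Tuple[int, int], Tuple[int, int]] = ((1, 128), (128, 128))
SUPPORTED_FP8_MODES = ("row", "block")


def fp8_granularity_spec(fp8_mode: str) -> Dict[str, object]:
    """Hypothetical helper: describe the granularity without importing torchao."""
    if fp8_mode == "row":
        return {"kind": "per_row"}
    if fp8_mode == "block":
        activation_shape, weight_shape = FP8_BLOCK_SHAPES
        return {
            "kind": "per_block",
            "activation_block": list(activation_shape),
            "weight_block": list(weight_shape),
        }
    # Reject unknown modes explicitly rather than falling through.
    raise ValueError(
        f"Unsupported fp8_mode {fp8_mode!r}; expected one of {SUPPORTED_FP8_MODES}"
    )
```

In the real helper, the two shapes would be handed to `PerBlock(...)` as in the diff; keeping them in one named constant gives a single place to change them.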

Comment thread unsloth_zoo/vllm_utils.py
is_vision_model : bool = False,
return_args : bool = False, # Just return args
max_num_seqs : int = 256, # how many seqs to process in parallel. Default vLLM 256
fp8_mode : Optional[str] = None,

medium

The new fp8_mode parameter enables on-the-fly FP8 quantization, which is a significant new feature. However, the test function _test_get_vllm_state_dict does not seem to be updated to exercise this new code path. It doesn't pass fp8_mode to load_vllm, nor does it seem to test the FP8 logic in get_vllm_state_dict.

It is highly recommended to add tests for this new functionality to ensure its correctness and prevent future regressions.

@chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +831 to 834
if get_quant_type(config) == "torchao" or load_in_fp8:
    ctx_manager = torch.no_grad()
else:
    ctx_manager = torch.inference_mode()

P1: Treat fp8_mode loads as torchao when exporting state dict

When a model is quantized via the new fp8_mode path, the HF config still has no quantization_config, so get_quant_type(config) returns None and this branch keeps using torch.inference_mode. With torchao-quantized weights this raises PyTorch’s “Cannot set version_counter for inference tensor” error (the exact bug called out in the comment) when calling get_vllm_state_dict after a vLLM FP8 load. In practice, load_vllm(..., fp8_mode=…) followed by get_vllm_state_dict(llm, config=config) still crashes unless callers remember to pass the new load_in_fp8=True flag manually. The fp8 loader should automatically switch to torch.no_grad (or propagate the fp8 flag) so exporting a torchao FP8 model works by default.
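
A small sketch of the flag propagation this comment asks for. `should_use_no_grad` and its `engine_quantization` argument are hypothetical: the idea is that whatever `load_vllm` wrote into `engine_args["quantization"]` gets consulted at export time, so callers do not have to remember `load_in_fp8=True`. How that value would be read back from the loaded engine is not shown on this page and is left as an assumption.

```python
from typing import Optional


def should_use_no_grad(
    config_quant_type: Optional[str],
    load_in_fp8: bool = False,
    engine_quantization: Optional[str] = None,
) -> bool:
    """Return True when the weights may be torchao tensor subclasses."""
    return (
        config_quant_type == "torchao"   # checkpoint config declares torchao
        or load_in_fp8                   # caller passed the flag explicitly
        or engine_quantization == "torchao"  # engine was loaded via the fp8_mode path
    )
```

The caller would then pick `torch.no_grad()` when this returns `True` and `torch.inference_mode()` otherwise, matching the branch in the diff.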

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +2105 to +2109
    torchao_config = _get_torchao_fp8_config(fp8_mode)
    hf_overrides = {
        "quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)),
    }
    engine_args["quantization"] = "torchao"

P1: fp8_mode keeps bitsandbytes loader despite torchao quantization

In the fp8 path we set engine_args["quantization"] = "torchao" and supply torchao overrides, but we never reset the bitsandbytes defaults established earlier (load_format and memory estimates stay in 4-bit mode because use_bitsandbytes defaults to True). As a result, calling load_vllm(..., fp8_mode="row") with default arguments will attempt to load the checkpoint through the bitsandbytes loader (which expects 4-bit weights) even though we now intend to quantize to FP8 on the fly, leading to incompatible weight loading on standard BF16/FP16 checkpoints. The fp8 branch should disable bitsandbytes/load_format or force use_bitsandbytes=False so the torchao quantizer can consume the original weights.
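
The conflict described here can be illustrated with a small, self-contained sketch. `reconcile_quantization_flags` is a hypothetical helper, and the returned keys only mirror the names used in the comment (`use_bitsandbytes`, `quantization`, `load_format`); the actual defaults inside `load_vllm` are not shown on this page, so the non-FP8 branches are assumptions.

```python
from typing import Dict, Optional


def reconcile_quantization_flags(
    use_bitsandbytes: bool,
    fp8_mode: Optional[str],
) -> Dict[str, object]:
    """Hypothetical helper: make the FP8 path win over bitsandbytes defaults."""
    if fp8_mode is not None:
        # torchao quantizes the original BF16/FP16 weights on the fly, so the
        # bitsandbytes loader must not be used for this load.
        return {
            "use_bitsandbytes": False,
            "quantization": "torchao",
            "load_format": "auto",
        }
    if use_bitsandbytes:
        return {
            "use_bitsandbytes": True,
            "quantization": "bitsandbytes",
            "load_format": "bitsandbytes",
        }
    return {"use_bitsandbytes": False, "quantization": None, "load_format": "auto"}
```

Memory estimates that assume 4-bit weights would need the same treatment; that part is omitted here.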

Comment thread unsloth_zoo/vllm_utils.py
  # For now, we work around this issue by using torch.no_grad in this case.
  # See https://github.com/pytorch/pytorch/issues/164872 for more details
- if get_quant_type(config) == "torchao":
+ if get_quant_type(config) == "torchao" or load_in_fp8:
Collaborator

Is this strictly necessary? I don't remember this being needed for offline quant FP8.

Contributor Author

Yeah, this was actually necessary for offline FP8 quant as well. It's just that now we don't require the config to explicitly say "torchao FP8" (it can be any bf16 checkpoint); instead, we handle the FP8 quantization dynamically through vLLM. We still need no_grad because the weights are still tensor subclasses.

@danielhanchen
Member

Rebased on latest main. The zoo changes look good -- _get_torchao_fp8_config, fp8_mode param in load_vllm, load_in_fp8 param in get_vllm_state_dict, and the torch.no_grad() fix are all correct.

The mapper bypass fix was applied on the unsloth side (PR #3717).

danielhanchen added a commit to unslothai/unsloth that referenced this pull request Feb 10, 2026
* FP8: Load model on-the-fly in vLLM

**Summary:** Existing support for `load_in_fp8=True` performs
an offline quantization when loading the initial model.
This is no longer necessary as of vllm==0.12.0 (after
vllm-project/vllm#23014), where we
can quantize the model on-the-fly when we load it:

```
llm = LLM(
  ...
  hf_overrides={
    "quantization_config_dict_str": json.dumps(torchao_config),
  },
)
```

**Note:** Needs unslothai/unsloth-zoo#380

**Test Plan:**
https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix on-the-fly FP8: always check mapper first, fallback to on-the-fly

The original implementation bypasses the FP8 mapper entirely for
vllm >= 0.12.0, meaning models like Llama-3.2-1B-Instruct and Qwen3-8B
that have pre-quantized FP8-Block/FP8 checkpoints would never use them.

This fixes the priority order:
1. Mapper has a pre-quantized model -> use it (always)
2. Mapper has no match + vllm >= 0.12.0 -> on-the-fly FP8 via torchao
3. Mapper has no match + vllm < 0.12.0 -> offline quantization

Changes:
- loader_utils.py: Move vllm >= 0.12.0 check after mapper lookups
- loader.py: Set load_in_fp8=False when mapper resolves to a
  pre-quantized model to prevent double quantization

Tested on B200 with Llama-3.2-1B-Instruct and Qwen3-8B. Corrected code
produces results matching baseline (pre-quantized path preserved).

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
@danielhanchen merged commit 95e37ea into unslothai:main on Feb 10, 2026
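
The three-step priority order from the commit message above can be sketched as a single decision function. All names here (`Fp8Strategy`, `choose_fp8_strategy`, the `fp8_mapper` dict) are hypothetical and only illustrate the ordering; the real logic lives in `loader_utils.py` and `loader.py` and is not reproduced on this page.

```python
from enum import Enum
from typing import Dict, Tuple


class Fp8Strategy(Enum):
    PREQUANTIZED = "prequantized"
    ON_THE_FLY = "on_the_fly"
    OFFLINE = "offline"


def choose_fp8_strategy(
    model_name: str,
    fp8_mapper: Dict[str, str],
    vllm_supports_on_the_fly: bool,
) -> Tuple[Fp8Strategy, str, bool]:
    """Return (strategy, model to load, load_in_fp8 flag to pass on)."""
    prequantized = fp8_mapper.get(model_name)
    if prequantized is not None:
        # 1. Already FP8 on disk: disable load_in_fp8 to avoid quantizing twice.
        return Fp8Strategy.PREQUANTIZED, prequantized, False
    if vllm_supports_on_the_fly:
        # 2. No mapper match and vllm >= 0.12.0: quantize while loading.
        return Fp8Strategy.ON_THE_FLY, model_name, True
    # 3. No mapper match and older vllm: fall back to offline quantization.
    return Fp8Strategy.OFFLINE, model_name, True


# Placeholder names for illustration only, not real repositories.
mapper = {"org/model-a": "org/model-a-FP8"}
print(choose_fp8_strategy("org/model-a", mapper, True))
print(choose_fp8_strategy("org/model-b", mapper, True))
print(choose_fp8_strategy("org/model-b", mapper, False))
```

Checking the mapper before the version gate is the point of the fix: a pre-quantized checkpoint is used regardless of the installed vLLM version.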
29583855 pushed a commit to 29583855/unsloth that referenced this pull request Mar 20, 2026
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
3 participants