FP8: Load model on-the-fly in vLLM#380
Conversation
Summary of ChangesHello @andrewor14, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the model loading capabilities within vLLM by introducing on-the-fly FP8 quantization. This feature, available for vLLM versions 0.12.0 and above, leverages Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
**Summary:** Existing support for `load_in_fp8=True` performs an offline quantization when loading the initial model. This is no longer necessary as of vllm==0.12.0 (after vllm-project/vllm#23014), where we can quantize the model on-the-fly when we load it: ``` llm = LLM( ... hf_overrides={ "quantization_config_dict_str": json.dumps(torchao_config), }, ) ``` **Note:** Needs unslothai/unsloth-zoo#380 **Test Plan:** https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423
There was a problem hiding this comment.
Code Review
This pull request introduces support for on-the-fly FP8 quantization in vLLM. It adds a new fp8_mode parameter to load_vllm to configure the vLLM engine for torchao-based quantization. A helper function _get_torchao_fp8_config is added to generate the necessary configuration. The get_vllm_state_dict function is also updated to handle FP8 models correctly.
My review focuses on improving the clarity, maintainability, and correctness of the new implementation. I've pointed out a potential issue where the fp8_mode is silently ignored for older vLLM versions, which could be misleading. I've also suggested improvements to the documentation and handling of hardcoded values. Finally, I've noted that the new FP8 functionality does not appear to be covered by tests in this file, and I recommend adding them.
| if fp8_mode is not None and Version(vllm_version) >= Version("0.12.0"): | ||
| from torchao.core.config import config_to_dict | ||
| torchao_config = _get_torchao_fp8_config(fp8_mode) | ||
| hf_overrides = { | ||
| "quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)), | ||
| } | ||
| engine_args["quantization"] = "torchao" | ||
| engine_args["hf_overrides"] = hf_overrides |
There was a problem hiding this comment.
When fp8_mode is specified but the vLLM version is older than 0.12.0, the condition is false and the block is skipped. This means the fp8_mode is silently ignored, and the model will not be quantized with FP8 as requested. This can be misleading for the user.
Consider adding an else block to handle this case, for example by raising a NotImplementedError or at least logging a prominent warning that on-the-fly FP8 quantization is not supported and is being skipped.
| def _get_torchao_fp8_config(fp8_mode: str): | ||
| """ | ||
| Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig` | ||
| to be used for `load_in_fp8=True`. |
There was a problem hiding this comment.
The docstring mentions load_in_fp8=True, but this function is called when fp8_mode is set in load_vllm. The parameter load_in_fp8 exists in get_vllm_state_dict. This is a bit confusing. To improve clarity, consider updating the docstring to refer to fp8_mode.
| to be used for `load_in_fp8=True`. | |
| to be used when `fp8_mode` is set. |
| if fp8_mode == "row": | ||
| granularity = PerRow() | ||
| elif fp8_mode == "block": | ||
| granularity = (PerBlock([1, 128]), PerBlock([128, 128])) |
There was a problem hiding this comment.
| is_vision_model : bool = False, | ||
| return_args : bool = False, # Just return args | ||
| max_num_seqs : int = 256, # how many seqs to process in parallel. Default vLLM 256 | ||
| fp8_mode : Optional[str] = None, |
There was a problem hiding this comment.
The new fp8_mode parameter enables on-the-fly FP8 quantization, which is a significant new feature. However, the test function _test_get_vllm_state_dict does not seem to be updated to exercise this new code path. It doesn't pass fp8_mode to load_vllm, nor does it seem to test the FP8 logic in get_vllm_state_dict.
It is highly recommended to add tests for this new functionality to ensure its correctness and prevent future regressions.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if get_quant_type(config) == "torchao" or load_in_fp8: | ||
| ctx_manager = torch.no_grad() | ||
| else: | ||
| ctx_manager = torch.inference_mode() |
There was a problem hiding this comment.
Treat fp8_mode loads as torchao when exporting state dict
When a model is quantized via the new fp8_mode path, the HF config still has no quantization_config, so get_quant_type(config) returns None and this branch keeps using torch.inference_mode. With torchao-quantized weights this raises PyTorch’s “Cannot set version_counter for inference tensor” error (the exact bug called out in the comment) when calling get_vllm_state_dict after a vLLM FP8 load. In practice, load_vllm(..., fp8_mode=…) followed by get_vllm_state_dict(llm, config=config) still crashes unless callers remember to pass the new load_in_fp8=True flag manually. The fp8 loader should automatically switch to torch.no_grad (or propagate the fp8 flag) so exporting a torchao FP8 model works by default.
Useful? React with 👍 / 👎.
| torchao_config = _get_torchao_fp8_config(fp8_mode) | ||
| hf_overrides = { | ||
| "quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)), | ||
| } | ||
| engine_args["quantization"] = "torchao" |
There was a problem hiding this comment.
fp8_mode keeps bitsandbytes loader despite torchao quantization
In the fp8 path we set engine_args["quantization"] = "torchao" and supply torchao overrides, but we never reset the bitsandbytes defaults established earlier (load_format and memory estimates stay in 4-bit mode because use_bitsandbytes defaults to True). As a result, calling load_vllm(..., fp8_mode="row") with default arguments will attempt to load the checkpoint through the bitsandbytes loader (which expects 4-bit weights) even though we now intend to quantize to FP8 on the fly, leading to incompatible weight loading on standard BF16/FP16 checkpoints. The fp8 branch should disable bitsandbytes/load_format or force use_bitsandbytes=False so the torchao quantizer can consume the original weights.
Useful? React with 👍 / 👎.
| # For now, we work around this issue by using torch.no_grad in this case. | ||
| # See https://github.com/pytorch/pytorch/issues/164872 for more details | ||
| if get_quant_type(config) == "torchao": | ||
| if get_quant_type(config) == "torchao" or load_in_fp8: |
There was a problem hiding this comment.
Is this strictly necessary? I don't remember this being needed for offline quant FP8.
There was a problem hiding this comment.
Yeah this was actually necessary for offline FP8 quant as well. It's just that now we don't require the config to explicitly say "torchao FP8" (it can be any bf16 checkpoint), and instead we handle the FP8 quantization dynamically through vllm, but we still need no_grad cause it's still using tensor subclasses
686d4fd to
1a14140
Compare
|
Rebased on latest main. The zoo changes look good -- The mapper bypass fix was applied on the unsloth side (PR #3717). |
* FP8: Load model on-the-fly in vLLM **Summary:** Existing support for `load_in_fp8=True` performs an offline quantization when loading the initial model. This is no longer necessary as of vllm==0.12.0 (after vllm-project/vllm#23014), where we can quantize the model on-the-fly when we load it: ``` llm = LLM( ... hf_overrides={ "quantization_config_dict_str": json.dumps(torchao_config), }, ) ``` **Note:** Needs unslothai/unsloth-zoo#380 **Test Plan:** https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix on-the-fly FP8: always check mapper first, fallback to on-the-fly The original implementation bypasses the FP8 mapper entirely for vllm >= 0.12.0, meaning models like Llama-3.2-1B-Instruct and Qwen3-8B that have pre-quantized FP8-Block/FP8 checkpoints would never use them. This fixes the priority order: 1. Mapper has a pre-quantized model -> use it (always) 2. Mapper has no match + vllm >= 0.12.0 -> on-the-fly FP8 via torchao 3. Mapper has no match + vllm < 0.12.0 -> offline quantization Changes: - loader_utils.py: Move vllm >= 0.12.0 check after mapper lookups - loader.py: Set load_in_fp8=False when mapper resolves to a pre-quantized model to prevent double quantization Tested on B200 with Llama-3.2-1B-Instruct and Qwen3-8B. Corrected code produces results matching baseline (pre-quantized path preserved). --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
* FP8: Load model on-the-fly in vLLM **Summary:** Existing support for `load_in_fp8=True` performs an offline quantization when loading the initial model. This is no longer necessary as of vllm==0.12.0 (after vllm-project/vllm#23014), where we can quantize the model on-the-fly when we load it: ``` llm = LLM( ... hf_overrides={ "quantization_config_dict_str": json.dumps(torchao_config), }, ) ``` **Note:** Needs unslothai/unsloth-zoo#380 **Test Plan:** https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix on-the-fly FP8: always check mapper first, fallback to on-the-fly The original implementation bypasses the FP8 mapper entirely for vllm >= 0.12.0, meaning models like Llama-3.2-1B-Instruct and Qwen3-8B that have pre-quantized FP8-Block/FP8 checkpoints would never use them. This fixes the priority order: 1. Mapper has a pre-quantized model -> use it (always) 2. Mapper has no match + vllm >= 0.12.0 -> on-the-fly FP8 via torchao 3. Mapper has no match + vllm < 0.12.0 -> offline quantization Changes: - loader_utils.py: Move vllm >= 0.12.0 check after mapper lookups - loader.py: Set load_in_fp8=False when mapper resolves to a pre-quantized model to prevent double quantization Tested on B200 with Llama-3.2-1B-Instruct and Qwen3-8B. Corrected code produces results matching baseline (pre-quantized path preserved). --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
* FP8: Load model on-the-fly in vLLM **Summary:** Existing support for `load_in_fp8=True` performs an offline quantization when loading the initial model. This is no longer necessary as of vllm==0.12.0 (after vllm-project/vllm#23014), where we can quantize the model on-the-fly when we load it: ``` llm = LLM( ... hf_overrides={ "quantization_config_dict_str": json.dumps(torchao_config), }, ) ``` **Note:** Needs unslothai/unsloth-zoo#380 **Test Plan:** https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix on-the-fly FP8: always check mapper first, fallback to on-the-fly The original implementation bypasses the FP8 mapper entirely for vllm >= 0.12.0, meaning models like Llama-3.2-1B-Instruct and Qwen3-8B that have pre-quantized FP8-Block/FP8 checkpoints would never use them. This fixes the priority order: 1. Mapper has a pre-quantized model -> use it (always) 2. Mapper has no match + vllm >= 0.12.0 -> on-the-fly FP8 via torchao 3. Mapper has no match + vllm < 0.12.0 -> offline quantization Changes: - loader_utils.py: Move vllm >= 0.12.0 check after mapper lookups - loader.py: Set load_in_fp8=False when mapper resolves to a pre-quantized model to prevent double quantization Tested on B200 with Llama-3.2-1B-Instruct and Qwen3-8B. Corrected code produces results matching baseline (pre-quantized path preserved). --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Supports unslothai/unsloth#3717