Merged
47 changes: 45 additions & 2 deletions unsloth_zoo/vllm_utils.py
@@ -845,7 +845,13 @@ def vllm_dynamic_quant_supported(
pass


-def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision_model = False):
+def get_vllm_state_dict(
+    llm,
+    return_state_dict = False,
+    config = None,
+    is_vision_model = False,
+    load_in_fp8 = False,
+):
# If the vllm state dict was quantized using torchao, we will run into
# the following error when calling ops like aten.t() in inference mode.
# This is a bug in PyTorch that affects all tensor subclasses.
@@ -854,7 +860,7 @@ def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision
#
# For now, we work around this issue by using torch.no_grad in this case.
# See https://github.com/pytorch/pytorch/issues/164872 for more details
-    if get_quant_type(config) == "torchao":
+    if get_quant_type(config) == "torchao" or load_in_fp8:

Collaborator

Is this strictly necessary? I don't remember this being needed for offline quant FP8.

Contributor Author

Yes, this was necessary for offline FP8 quant as well. The difference now is that the config no longer has to say "torchao FP8" explicitly (it can be any bf16 checkpoint); the FP8 quantization is handled dynamically through vLLM instead. We still need no_grad because the weights are still tensor subclasses.

ctx_manager = torch.no_grad()
else:
ctx_manager = torch.inference_mode()
Comment on lines +863 to 866

P1: Treat fp8_mode loads as torchao when exporting state dict

When a model is quantized via the new fp8_mode path, the HF config still has no quantization_config, so get_quant_type(config) returns None and this branch keeps using torch.inference_mode. With torchao-quantized weights this raises PyTorch’s “Cannot set version_counter for inference tensor” error (the exact bug called out in the comment) when calling get_vllm_state_dict after a vLLM FP8 load. In practice, load_vllm(..., fp8_mode=…) followed by get_vllm_state_dict(llm, config=config) still crashes unless callers remember to pass the new load_in_fp8=True flag manually. The fp8 loader should automatically switch to torch.no_grad (or propagate the fp8 flag) so exporting a torchao FP8 model works by default.
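One way the suggestion above could be sketched, assuming a hypothetical helper `needs_no_grad` (not part of this PR) that folds the `fp8_mode` used at load time into the decision, so callers do not have to remember `load_in_fp8=True`. The rule it encodes is only what this thread states: torchao tensor subclasses need `torch.no_grad` instead of `torch.inference_mode`.

```python
from typing import Optional


def needs_no_grad(
    quant_type: Optional[str],
    fp8_mode: Optional[str] = None,
    load_in_fp8: bool = False,
) -> bool:
    """Return True when the vLLM weights are expected to be torchao tensor subclasses.

    Hypothetical helper for illustration; the name and signature are assumptions.
    """
    # Any of these three signals means the weights went through torchao:
    #   - the HF config declares a torchao quantization_config
    #   - the caller passed load_in_fp8=True explicitly
    #   - the model was loaded with fp8_mode set ("row" or "block")
    return quant_type == "torchao" or load_in_fp8 or fp8_mode is not None


# Inside get_vllm_state_dict the result could select the context manager, e.g.:
#   ctx_manager = torch.no_grad() if needs_no_grad(...) else torch.inference_mode()
```

This still requires `fp8_mode` to reach `get_vllm_state_dict` somehow, for example by storing it on the returned `llm` object; how to propagate it is left open here.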


@@ -1706,6 +1712,30 @@ def _module_disallows_flashinfer(module) -> bool:
pass


def _get_torchao_fp8_config(fp8_mode: str):
"""
Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig`
to be used for `load_in_fp8=True`.

Contributor

medium

The docstring mentions load_in_fp8=True, but this function is called when fp8_mode is set in load_vllm. The parameter load_in_fp8 exists in get_vllm_state_dict. This is a bit confusing. To improve clarity, consider updating the docstring to refer to fp8_mode.

Suggested change
-    to be used for `load_in_fp8=True`.
+    to be used when `fp8_mode` is set.

"""
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
PerBlock,
PerRow,
)

if fp8_mode == "row":
granularity = PerRow()
elif fp8_mode == "block":
granularity = (PerBlock([1, 128]), PerBlock([128, 128]))

Contributor

medium

The block granularity for FP8 quantization is hardcoded. While this might be the intended configuration for now, it reduces flexibility. Consider defining this as a constant at the module level. This would make it easier to find and change if needed in the future, improving maintainability.
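A minimal sketch of what that could look like, assuming hypothetical module-level names (`FP8_ACTIVATION_BLOCK`, `FP8_WEIGHT_BLOCK`, `SUPPORTED_FP8_MODES`, `build_fp8_config`) that do not exist in the PR. The block shapes and the `torchao` calls are copied from the diff; reading the tuple as (activation, weight) granularity is an assumption. The mode is validated before the `torchao` import so that a bad value fails fast.

```python
# Hypothetical module-level constants; shapes are the ones hardcoded in the diff.
FP8_ACTIVATION_BLOCK = [1, 128]    # assumed to be the activation granularity
FP8_WEIGHT_BLOCK = [128, 128]      # assumed to be the weight granularity
SUPPORTED_FP8_MODES = ("row", "block")


def build_fp8_config(fp8_mode: str):
    """Build the torchao FP8 config used when `fp8_mode` is set."""
    # Validate first so an unsupported mode never reaches the torchao import.
    if fp8_mode not in SUPPORTED_FP8_MODES:
        raise ValueError(
            f"Unsloth: `fp8_mode` supports only {SUPPORTED_FP8_MODES}, got {fp8_mode!r}"
        )

    # Imported lazily, as in the diff, so torchao stays an optional dependency.
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerBlock,
        PerRow,
    )

    if fp8_mode == "row":
        granularity = PerRow()
    else:
        granularity = (PerBlock(FP8_ACTIVATION_BLOCK), PerBlock(FP8_WEIGHT_BLOCK))

    return Float8DynamicActivationFloat8WeightConfig(
        granularity = granularity,
        activation_value_lb = 1e-12,
    )
```

The error message here names `fp8_mode` rather than `load_in_fp8`, which would also keep it consistent with the docstring suggestion earlier in this review.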

else:
raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")

return Float8DynamicActivationFloat8WeightConfig(
granularity = granularity,
activation_value_lb = 1e-12,
)


def load_vllm(
model_name : str = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
config = None,
@@ -1731,6 +1761,7 @@ def load_vllm(
is_vision_model : bool = False,
return_args : bool = False, # Just return args
max_num_seqs : int = 256, # how many seqs to process in parallel. Default vLLM 256
fp8_mode : Optional[str] = None,

Contributor

medium

The new fp8_mode parameter enables on-the-fly FP8 quantization, which is a significant new feature. However, the test function _test_get_vllm_state_dict does not seem to be updated to exercise this new code path. It doesn't pass fp8_mode to load_vllm, nor does it seem to test the FP8 logic in get_vllm_state_dict.

It is highly recommended to add tests for this new functionality to ensure its correctness and prevent future regressions.

):
# All Unsloth Zoo code licensed under LGPLv3
# Create vLLM instance
@@ -2171,6 +2202,18 @@ def load_vllm(
engine_args["disable_cascade_attn"] = disable_cascade_attn
pass

# On-the-fly quantization is added in https://github.com/vllm-project/vllm/pull/23014
# This is only available in vllm >= 0.12.0. Older versions will do offline quantization
# by creating an FP8 checkpoint and loading it back in
if fp8_mode is not None and Version(vllm_version) >= Version("0.12.0"):
from torchao.core.config import config_to_dict
torchao_config = _get_torchao_fp8_config(fp8_mode)
hf_overrides = {
"quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)),
}
engine_args["quantization"] = "torchao"
Comment on lines +2210 to +2214

P1: fp8_mode keeps bitsandbytes loader despite torchao quantization

In the fp8 path we set engine_args["quantization"] = "torchao" and supply torchao overrides, but we never reset the bitsandbytes defaults established earlier (load_format and memory estimates stay in 4-bit mode because use_bitsandbytes defaults to True). As a result, calling load_vllm(..., fp8_mode="row") with default arguments will attempt to load the checkpoint through the bitsandbytes loader (which expects 4-bit weights) even though we now intend to quantize to FP8 on the fly, leading to incompatible weight loading on standard BF16/FP16 checkpoints. The fp8 branch should disable bitsandbytes/load_format or force use_bitsandbytes=False so the torchao quantizer can consume the original weights.
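A rough sketch of the fix described above, written as a pure function over the `engine_args` dict so it can be checked without vLLM installed. `apply_fp8_engine_args` is a hypothetical name. The sketch assumes the earlier bitsandbytes branch sets `load_format` to `"bitsandbytes"` and that `"auto"` is an acceptable value to reset it to; neither is shown in this diff.

```python
import json
from typing import Any, Dict, Optional


def apply_fp8_engine_args(
    engine_args: Dict[str, Any],
    fp8_mode: Optional[str],
    torchao_config_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a copy of engine_args adjusted for on-the-fly torchao FP8.

    Hypothetical helper; torchao_config_dict stands in for the result of
    config_to_dict(torchao_config) in the diff.
    """
    updated = dict(engine_args)
    if fp8_mode is None:
        return updated

    # Assumption: undo the bitsandbytes loader so the original bf16/fp16
    # weights are read and then quantized by torchao.
    if updated.get("load_format") == "bitsandbytes":
        updated["load_format"] = "auto"

    updated["quantization"] = "torchao"

    # Merge into any existing hf_overrides instead of replacing them.
    hf_overrides = dict(updated.get("hf_overrides") or {})
    hf_overrides["quantization_config_dict_json"] = json.dumps(torchao_config_dict)
    updated["hf_overrides"] = hf_overrides
    return updated
```

Memory estimates derived from `use_bitsandbytes` earlier in `load_vllm` would still need the same treatment; this sketch only covers the engine arguments.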


engine_args["hf_overrides"] = hf_overrides
Comment on lines +2208 to +2215

Contributor

high

When fp8_mode is specified but the vLLM version is older than 0.12.0, the condition is false and the block is skipped. This means the fp8_mode is silently ignored, and the model will not be quantized with FP8 as requested. This can be misleading for the user.

Consider adding an else block to handle this case, for example by raising a NotImplementedError or at least logging a prominent warning that on-the-fly FP8 quantization is not supported and is being skipped.


good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys()
old_keys = list(engine_args.keys())
for key in old_keys: