Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_get_inference_mode_context_manager,
_prepare_model_for_qat,
)
from .loader_utils import _get_fp8_mode_and_check_settings
from ..utils.packing import (
get_packed_info_from_kwargs,
mask_packed_sequence_boundaries,
Expand Down Expand Up @@ -2192,6 +2193,7 @@ def from_pretrained(
unsloth_vllm_standby = False,
num_labels = None,
qat_scheme = None,
load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
**kwargs,
):
os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
Expand Down Expand Up @@ -2435,6 +2437,13 @@ def from_pretrained(
generate_batches,
)

fp8_mode = None
if load_in_fp8 != False:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For boolean checks, it's more idiomatic in Python to use the truthiness of the value directly rather than comparing with False. The load_in_fp8 parameter can be True, False, or a string like 'block'. Both True and non-empty strings are truthy, while False is falsy. Using if load_in_fp8: is more concise and readable, and achieves the same result.

Suggested change
if load_in_fp8 != False:
if load_in_fp8:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah agree with gemini here :)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can change it — I only wrote it this way because that's how Daniel wrote it in a few existing places.

fp8_mode = _get_fp8_mode_and_check_settings(
load_in_fp8,
fast_inference,
)

allowed_args = inspect.getfullargspec(load_vllm).args
load_vllm_kwargs = dict(
model_name = model_name,
Expand All @@ -2448,6 +2457,7 @@ def from_pretrained(
disable_log_stats = disable_log_stats,
use_bitsandbytes = load_in_4bit,
unsloth_vllm_standby = unsloth_vllm_standby,
fp8_mode = fp8_mode,
)
for allowed_arg in allowed_args:
if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
Expand All @@ -2458,7 +2468,11 @@ def from_pretrained(
llm = load_vllm(**load_vllm_kwargs)

# Convert to HF format
_, quant_state_dict = get_vllm_state_dict(llm, config = model_config)
_, quant_state_dict = get_vllm_state_dict(
llm,
config = model_config,
load_in_fp8 = load_in_fp8,
)
model = convert_vllm_to_huggingface(
quant_state_dict, model_config, dtype, bnb_config
)
Expand Down
12 changes: 10 additions & 2 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,15 @@ def from_pretrained(
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)
model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
else:
Comment on lines 290 to 295
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Enforce FP8/4-bit mutual exclusion for vLLM >= 0.12

When load_in_fp8 is true, get_model_name now returns the original name as soon as vLLM ≥ 0.12.0 (loader_utils.py lines 110-118), so the new_model_name is None branch here is never taken and _get_fp8_mode_and_check_settings no longer runs. With the default load_in_4bit=True, the code now proceeds to fast inference with both load_in_fp8 and use_bitsandbytes=load_in_4bit set, even though _get_fp8_mode_and_check_settings used to reject FP8 together with 4/8/16-bit loads. This yields conflicting quantization paths (fp8 on-the-fly plus bitsandbytes 4bit) and is likely to fail at runtime for users who simply enable load_in_fp8 without also disabling 4bit.

Useful? React with 👍 / 👎.

assert new_model_name is not None
model_name = new_model_name
# If mapper resolved to a pre-quantized FP8 model, disable
# on-the-fly quantization to avoid double quantization
if load_in_fp8 != False and new_model_name != old_model_name:
load_in_fp8 = False

# Check if pre-quantized models are allowed
# For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
Expand Down Expand Up @@ -615,6 +618,7 @@ def from_pretrained(
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
load_in_fp8 = load_in_fp8,
*args,
**kwargs,
)
Expand Down Expand Up @@ -894,12 +898,15 @@ def from_pretrained(
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)
model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
else:
assert new_model_name is not None
model_name = new_model_name
# If mapper resolved to a pre-quantized FP8 model, disable
# on-the-fly quantization to avoid double quantization
if load_in_fp8 != False and new_model_name != old_model_name:
load_in_fp8 = False

# Check if pre-quantized models are allowed
# For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
Expand Down Expand Up @@ -1311,6 +1318,7 @@ def from_pretrained(
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
load_in_fp8 = load_in_fp8,
*args,
**kwargs,
)
Expand Down
49 changes: 15 additions & 34 deletions unsloth/models/loader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from transformers import __version__ as transformers_version
from unsloth.models._utils import TorchAOConfig
from unsloth_zoo.utils import Version
from unsloth_zoo.vllm_utils import _get_torchao_fp8_config
import gc

transformers_version = Version(transformers_version)
Expand Down Expand Up @@ -117,6 +118,15 @@ def __get_model_name(
else:
if lower_model_name in FLOAT_TO_FP8_BLOCK_MAPPER:
return FLOAT_TO_FP8_BLOCK_MAPPER[lower_model_name]
# Mapper didn't find a pre-quantized model.
# For vllm >= 0.12.0, we can quantize the model to FP8 on the fly,
# so just return the original model name. Older vllm versions will
# fall through to offline quantization via _offline_quantize_to_fp8.
if importlib.util.find_spec("vllm") is not None:
import vllm

if Version(vllm.__version__) >= Version("0.12.0"):
return model_name
return None

elif not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:
Expand Down Expand Up @@ -235,38 +245,12 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
return new_model_name if new_model_name is not None else model_name


def _get_torchao_fp8_config(fp8_mode: str):
    """
    Build the torchao FP8 quantization config used when `load_in_fp8` is enabled.

    Parameters
    ----------
    fp8_mode : str
        Either ``"row"`` (per-row weight/activation scaling) or ``"block"``
        (128x128 weight blocks with 1x128 activation blocks).

    Returns
    -------
    A ``torchao.quantization.Float8DynamicActivationFloat8WeightConfig``
    configured with the requested granularity.

    Raises
    ------
    ValueError
        If ``fp8_mode`` is neither ``"row"`` nor ``"block"``.
    """
    # Imported lazily so torchao is only required when FP8 loading is requested.
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerBlock,
        PerRow,
    )

    # Guard clause: reject unknown modes up front.
    if fp8_mode not in ("row", "block"):
        raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")

    granularity = (
        PerRow()
        if fp8_mode == "row"
        else (PerBlock([1, 128]), PerBlock([128, 128]))
    )

    # activation_value_lb clamps tiny activation scales to avoid divide-by-zero
    # style numerical issues during dynamic activation quantization.
    return Float8DynamicActivationFloat8WeightConfig(
        granularity = granularity,
        activation_value_lb = 1e-12,
    )


def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
"""
Quantizes the model to fp8 using torchao and saving the quantized model to a
temporary location. Return the path to the quantized model.

Note: Once on-the-fly quantization is added in vllm in
https://github.com/vllm-project/vllm/pull/26327, we should
dynamically quantize the model there instead:
Note: For vllm >= 0.12.0, we should dynamically quantize the model in vllm instead:

llm = LLM(
...
Expand Down Expand Up @@ -333,11 +317,10 @@ def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
def _get_fp8_mode_and_check_settings(
load_in_fp8: Union[bool, str],
fast_inference: bool,
full_finetuning: bool,
load_in_4bit: bool,
load_in_8bit: bool,
load_in_16bit: bool,
use_exact_model_name: bool,
full_finetuning: bool = False,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
load_in_16bit: bool = False,
) -> str:
"""
Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings
Expand Down Expand Up @@ -373,8 +356,6 @@ def _get_fp8_mode_and_check_settings(
raise ValueError(
"Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`",
)
if use_exact_model_name:
raise ValueError("Unsloth: `load_in_fp8` requires `use_exact_model_name=False`")

# Check if this is Hopper or above
if not (
Expand Down
15 changes: 15 additions & 0 deletions unsloth/models/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from ._utils import __version__, importlib_version, _prepare_model_for_qat
from ._utils import *
from .loader_utils import _get_fp8_mode_and_check_settings
from ..save import patch_saving_functions
from ..models.loader_utils import is_distributed
from unsloth_zoo.gradient_checkpointing import (
Expand Down Expand Up @@ -433,6 +434,7 @@ def from_pretrained(
max_lora_rank = 64,
disable_log_stats = False,
unsloth_vllm_standby = False,
load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
**kwargs,
):
if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") != "1":
Expand Down Expand Up @@ -838,6 +840,17 @@ def from_pretrained(
model_name, model_config
)

fp8_mode = None
if load_in_fp8 != False:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check if load_in_fp8 != False: can be simplified to if load_in_fp8:. This is the more idiomatic and preferred way to check for truthiness in Python, improving code readability.

Suggested change
if load_in_fp8 != False:
if load_in_fp8:

fp8_mode = _get_fp8_mode_and_check_settings(
load_in_fp8,
fast_inference,
full_finetuning,
load_in_4bit,
load_in_8bit,
load_in_16bit,
)

allowed_args = inspect.getfullargspec(load_vllm).args
load_vllm_kwargs = dict(
model_name = model_name,
Expand All @@ -852,6 +865,7 @@ def from_pretrained(
use_bitsandbytes = load_in_4bit,
unsloth_vllm_standby = unsloth_vllm_standby,
is_vision_model = is_vlm,
fp8_mode = fp8_mode,
)
for allowed_arg in allowed_args:
if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
Expand All @@ -865,6 +879,7 @@ def from_pretrained(
llm,
config = model_config,
is_vision_model = is_vlm,
load_in_fp8 = load_in_fp8,
)
model = convert_vllm_to_huggingface(
quant_state_dict,
Expand Down