diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 89c25e7316..36856ee23d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -26,6 +26,7 @@ _get_inference_mode_context_manager, _prepare_model_for_qat, ) +from .loader_utils import _get_fp8_mode_and_check_settings from ..utils.packing import ( get_packed_info_from_kwargs, mask_packed_sequence_boundaries, @@ -2192,6 +2193,7 @@ def from_pretrained( unsloth_vllm_standby = False, num_labels = None, qat_scheme = None, + load_in_fp8 = False, # fp8 LoRA (True, False, 'block') **kwargs, ): os.environ["UNSLOTH_USE_NEW_MODEL"] = "0" @@ -2435,6 +2437,13 @@ def from_pretrained( generate_batches, ) + fp8_mode = None + if load_in_fp8 != False: + fp8_mode = _get_fp8_mode_and_check_settings( + load_in_fp8, + fast_inference, + ) + allowed_args = inspect.getfullargspec(load_vllm).args load_vllm_kwargs = dict( model_name = model_name, @@ -2448,6 +2457,7 @@ def from_pretrained( disable_log_stats = disable_log_stats, use_bitsandbytes = load_in_4bit, unsloth_vllm_standby = unsloth_vllm_standby, + fp8_mode = fp8_mode, ) for allowed_arg in allowed_args: if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: @@ -2458,7 +2468,11 @@ def from_pretrained( llm = load_vllm(**load_vllm_kwargs) # Convert to HF format - _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) + _, quant_state_dict = get_vllm_state_dict( + llm, + config = model_config, + load_in_fp8 = load_in_fp8, + ) model = convert_vllm_to_huggingface( quant_state_dict, model_config, dtype, bnb_config ) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 9d2d3d9b06..4054f1b7f5 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -290,12 +290,15 @@ def from_pretrained( load_in_4bit, load_in_8bit, load_in_16bit, - use_exact_model_name, ) model_name = _offline_quantize_to_fp8(model_name, fp8_mode) else: assert new_model_name is not None model_name = new_model_name + # If mapper resolved to a pre-quantized FP8 model, disable + # on-the-fly quantization to avoid double quantization + if load_in_fp8 != False and new_model_name != old_model_name: + load_in_fp8 = False # Check if pre-quantized models are allowed # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 @@ -615,6 +618,7 @@ def from_pretrained( random_state = random_state, max_lora_rank = max_lora_rank, disable_log_stats = disable_log_stats, + load_in_fp8 = load_in_fp8, *args, **kwargs, ) @@ -894,12 +898,15 @@ def from_pretrained( load_in_4bit, load_in_8bit, load_in_16bit, - use_exact_model_name, ) model_name = _offline_quantize_to_fp8(model_name, fp8_mode) else: assert new_model_name is not None model_name = new_model_name + # If mapper resolved to a pre-quantized FP8 model, disable + # on-the-fly quantization to avoid double quantization + if load_in_fp8 != False and new_model_name != old_model_name: + load_in_fp8 = False # Check if pre-quantized models are allowed # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 @@ -1311,6 +1318,7 @@ def from_pretrained( random_state = random_state, max_lora_rank = max_lora_rank, disable_log_stats = disable_log_stats, + load_in_fp8 = load_in_fp8, *args, **kwargs, ) diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py index 1e5533c25c..01d221c725 100644 --- a/unsloth/models/loader_utils.py +++ b/unsloth/models/loader_utils.py @@ -31,6 +31,7 @@ from transformers import __version__ as transformers_version from unsloth.models._utils import TorchAOConfig from unsloth_zoo.utils import Version +from unsloth_zoo.vllm_utils import _get_torchao_fp8_config import gc transformers_version = Version(transformers_version) @@ -117,6 +118,15 @@ def __get_model_name( else: if lower_model_name in FLOAT_TO_FP8_BLOCK_MAPPER: return FLOAT_TO_FP8_BLOCK_MAPPER[lower_model_name] + # Mapper didn't find a pre-quantized model. + # For vllm >= 0.12.0, we can quantize the model to FP8 on the fly, + # so just return the original model name. Older vllm versions will + # fall through to offline quantization via _offline_quantize_to_fp8. + if importlib.util.find_spec("vllm") is not None: + import vllm + + if Version(vllm.__version__) >= Version("0.12.0"): + return model_name return None elif not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER: @@ -235,38 +245,12 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False): return new_model_name if new_model_name is not None else model_name -def _get_torchao_fp8_config(fp8_mode: str): - """ - Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig` - to be used for `load_in_fp8=True`. - """ - from torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, - PerBlock, - PerRow, - ) - - if fp8_mode == "row": - granularity = PerRow() - elif fp8_mode == "block": - granularity = (PerBlock([1, 128]), PerBlock([128, 128])) - else: - raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'") - - return Float8DynamicActivationFloat8WeightConfig( - granularity = granularity, - activation_value_lb = 1e-12, - ) - - def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str: """ Quantizes the model to fp8 using torchao and saving the quantized model to a temporary location. Return the path to the quantized model. - Note: Once on-the-fly quantization is added in vllm in - https://github.com/vllm-project/vllm/pull/26327, we should - dynamically quantize the model there instead: + Note: For vllm >= 0.12.0, we should dynamically quantize the model in vllm instead: llm = LLM( ... @@ -333,11 +317,10 @@ def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str): def _get_fp8_mode_and_check_settings( load_in_fp8: Union[bool, str], fast_inference: bool, - full_finetuning: bool, - load_in_4bit: bool, - load_in_8bit: bool, - load_in_16bit: bool, - use_exact_model_name: bool, + full_finetuning: bool = False, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + load_in_16bit: bool = False, ) -> str: """ Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings @@ -373,8 +356,6 @@ def _get_fp8_mode_and_check_settings( raise ValueError( "Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`", ) - if use_exact_model_name: - raise ValueError("Unsloth: `load_in_fp8` requires `use_exact_model_name=False`") # Check if this is Hopper or above if not ( diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 735c28d917..3e6dc8ac5f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -31,6 +31,7 @@ ) from ._utils import __version__, importlib_version, _prepare_model_for_qat from ._utils import * +from .loader_utils import _get_fp8_mode_and_check_settings from ..save import patch_saving_functions from ..models.loader_utils import is_distributed from unsloth_zoo.gradient_checkpointing import ( @@ -433,6 +434,7 @@ def from_pretrained( max_lora_rank = 64, disable_log_stats = False, unsloth_vllm_standby = False, + load_in_fp8 = False, # fp8 LoRA (True, False, 'block') **kwargs, ): if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") != "1": @@ -838,6 +840,17 @@ def from_pretrained( model_name, model_config ) + fp8_mode = None + if load_in_fp8 != False: + fp8_mode = _get_fp8_mode_and_check_settings( + load_in_fp8, + fast_inference, + full_finetuning, + load_in_4bit, + load_in_8bit, + load_in_16bit, + ) + allowed_args = inspect.getfullargspec(load_vllm).args load_vllm_kwargs = dict( model_name = model_name, @@ -852,6 +865,7 @@ def from_pretrained( use_bitsandbytes = load_in_4bit, unsloth_vllm_standby = unsloth_vllm_standby, is_vision_model = is_vlm, + fp8_mode = fp8_mode, ) for allowed_arg in allowed_args: if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: @@ -865,6 +879,7 @@ def from_pretrained( llm, config = model_config, is_vision_model = is_vlm, + load_in_fp8 = load_in_fp8, ) model = convert_vllm_to_huggingface( quant_state_dict,