diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e78e66ad7..111d8c2be 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -845,7 +845,13 @@ def vllm_dynamic_quant_supported( pass -def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision_model = False): +def get_vllm_state_dict( + llm, + return_state_dict = False, + config = None, + is_vision_model = False, + load_in_fp8 = False, +): # If the vllm state dict was quantized using torchao, we will run into # the following error when calling ops like aten.t() in inference mode. # This is a bug in PyTorch that affects all tensor subclasses. @@ -854,7 +860,7 @@ def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision # # For now, we work around this issue by using torch.no_grad in this case. # See https://github.com/pytorch/pytorch/issues/164872 for more details - if get_quant_type(config) == "torchao": + if get_quant_type(config) == "torchao" or load_in_fp8: ctx_manager = torch.no_grad() else: ctx_manager = torch.inference_mode() @@ -1706,6 +1712,30 @@ def _module_disallows_flashinfer(module) -> bool: pass +def _get_torchao_fp8_config(fp8_mode: str): + """ + Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig` + to be used for `load_in_fp8=True`. + """ + from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + PerBlock, + PerRow, + ) + + if fp8_mode == "row": + granularity = PerRow() + elif fp8_mode == "block": + granularity = (PerBlock([1, 128]), PerBlock([128, 128])) + else: + raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'") + + return Float8DynamicActivationFloat8WeightConfig( + granularity = granularity, + activation_value_lb = 1e-12, + ) + + def load_vllm( model_name : str = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", config = None, @@ -1731,6 +1761,7 @@ def load_vllm( is_vision_model : bool = False, return_args : bool = False, # Just return args max_num_seqs : int = 256, # how many seqs to process in parallel. Default vLLM 256 + fp8_mode : Optional[str] = None, ): # All Unsloth Zoo code licensed under LGPLv3 # Create vLLM instance @@ -2171,6 +2202,18 @@ def load_vllm( engine_args["disable_cascade_attn"] = disable_cascade_attn pass + # On-the-fly quantization is added in https://github.com/vllm-project/vllm/pull/23014 + # This is only available in vllm >= 0.12.0. Older versions will do offline quantization + # by creating an FP8 checkpoint and loading it back in + if fp8_mode is not None and Version(vllm_version) >= Version("0.12.0"): + from torchao.core.config import config_to_dict + torchao_config = _get_torchao_fp8_config(fp8_mode) + hf_overrides = { + "quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)), + } + engine_args["quantization"] = "torchao" + engine_args["hf_overrides"] = hf_overrides + good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys() old_keys = list(engine_args.keys()) for key in old_keys: