diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d0d76a840..8631a6c3a 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -542,22 +542,29 @@ def create_empty_causal_lm(config, dtype = torch.float16): @torch.inference_mode -def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16): +def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, bnb_config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules to create HF compatible model config.update({"torch_dtype" : dtype}) # Do not use config file's dtype! new_model = create_empty_causal_lm(config, dtype) quantization_config = getattr(config, "quantization_config", {}) kwargs = dict() - if quantization_config != {}: + compute_dtype = dtype # Do not use config file's dtype! + + if quantization_config != {} or bnb_config is not None: # Get quantization_config flags - compute_dtype = _get_dtype(quantization_config["bnb_4bit_compute_dtype"]) - compute_dtype = dtype # Do not use config file's dtype! - kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] - kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] - kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) - pass + if quantization_config != {}: + kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] + kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] + kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) + # Get bnb_config flags + elif bnb_config is not None: + kwargs["compress_statistics"] = bnb_config.bnb_4bit_use_double_quant + kwargs["quant_type"] = bnb_config.bnb_4bit_quant_type + kwargs["quant_storage"] = _get_dtype(bnb_config.bnb_4bit_quant_storage) + + pass from bitsandbytes.nn.modules import Linear4bit, Params4bit from torch.nn.modules import Linear