diff --git a/unsloth/save.py b/unsloth/save.py
index 810bf9062a..0a8f02d90f 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -2769,7 +2769,13 @@ def unsloth_save_pretrained_torchao(
     for _ in range(3):
         gc.collect()
 
-    from transformers import AutoModel, AutoTokenizer, TorchAoConfig
+    from transformers import (
+        AutoModelForCausalLM,
+        AutoTokenizer,
+        TorchAoConfig,
+        AutoModelForImageTextToText,
+        AutoProcessor,
+    )
     from torchao import quantize_
 
     if torchao_config is None:
@@ -2781,14 +2787,25 @@
         torchao_config = Int8DynamicActivationInt8WeightConfig()
     quantization_config = TorchAoConfig(quant_type = torchao_config)
 
-    tokenizer = AutoTokenizer.from_pretrained(arguments["save_directory"])
+    is_vlm = False
+    if hasattr(self, "config") and hasattr(self.config, "architectures"):
+        is_vlm = any(
+            x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
+            for x in self.config.architectures
+        )
+        is_vlm = is_vlm or hasattr(self.config, "vision_config")
+    auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
+    auto_processor = AutoProcessor if is_vlm else AutoTokenizer
+
+    tokenizer = auto_processor.from_pretrained(arguments["save_directory"])
 
     # TorchAO must only use bfloat16 for loading (float16 fails)
     if HAS_TORCH_DTYPE:
         kwargs = {"torch_dtype": torch.bfloat16}
     else:
         kwargs = {"dtype": torch.bfloat16}
-    model = AutoModel.from_pretrained(
+
+    model = auto_model.from_pretrained(
         arguments["save_directory"],
         device_map = "auto",
         quantization_config = quantization_config,
@@ -2812,6 +2829,13 @@
         torchao_save_directory,
         safe_serialization = safe_serialization
     )
     tokenizer.save_pretrained(torchao_save_directory)
+    if os.path.exists(save_directory):
+        try:
+            import shutil
+
+            shutil.rmtree(save_directory)
+        except Exception:
+            pass
     for _ in range(3):
         gc.collect()