Skip to content
Merged
30 changes: 27 additions & 3 deletions unsloth/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,7 +2769,13 @@ def unsloth_save_pretrained_torchao(
for _ in range(3):
gc.collect()

from transformers import AutoModel, AutoTokenizer, TorchAoConfig
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TorchAoConfig,
AutoModelForImageTextToText,
AutoProcessor,
)
from torchao import quantize_

if torchao_config is None:
Expand All @@ -2781,14 +2787,25 @@ def unsloth_save_pretrained_torchao(
torchao_config = Int8DynamicActivationInt8WeightConfig()
quantization_config = TorchAoConfig(quant_type = torchao_config)

tokenizer = AutoTokenizer.from_pretrained(arguments["save_directory"])
is_vlm = False
if hasattr(self, "config") and hasattr(self.config, "architectures"):
is_vlm = any(
x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
for x in self.config.architectures
)
is_vlm = is_vlm or hasattr(self.config, "vision_config")
auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
auto_processor = AutoProcessor if is_vlm else AutoTokenizer

tokenizer = auto_processor.from_pretrained(arguments["save_directory"])

# TorchAO must only use bfloat16 for loading (float16 fails)
if HAS_TORCH_DTYPE:
kwargs = {"torch_dtype": torch.bfloat16}
else:
kwargs = {"dtype": torch.bfloat16}
model = AutoModel.from_pretrained(

model = auto_model.from_pretrained(
arguments["save_directory"],
device_map = "auto",
quantization_config = quantization_config,
Expand All @@ -2812,6 +2829,13 @@ def unsloth_save_pretrained_torchao(
torchao_save_directory, safe_serialization = safe_serialization
)
tokenizer.save_pretrained(torchao_save_directory)
if os.path.exists(save_directory):
try:
import shutil

shutil.rmtree(save_directory)
except:
pass
for _ in range(3):
gc.collect()

Expand Down