Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion unsloth_zoo/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def requires_grad_pre_hook(module, input):
raise RuntimeError("Unsloth: Failed to make input require gradients!")
# print(f" WARNING: Empty list input to {module.__class__.__name__}!") #
# return
input[0].requires_grad_(True)
if torch.is_floating_point(input[0]):
input[0].requires_grad_(True)
else:
raise RuntimeError("Unsloth: Failed to make input require gradients!")
pass
Expand Down
33 changes: 32 additions & 1 deletion unsloth_zoo/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
pass
from transformers.modeling_utils import PushToHubMixin
import json
import os
from pathlib import Path
import tempfile
from peft import PeftModelForCausalLM
Expand Down Expand Up @@ -540,7 +541,13 @@ def merge_and_overwrite_lora(
model_name = model.config._name_or_path

# Find repository's max shard size and total size of everything
file_list = HfFileSystem(token = token).ls(model_name, detail = True)
try:
file_list = HfFileSystem(token = token).ls(model_name, detail = True)
except:
original_model_id = get_original_model_id(model_name)
model_name = original_model_id
file_list = HfFileSystem(token = token).ls(model_name, detail = True)

safetensors_list = []
max_size_in_bytes = 0
total_size_in_bytes = 0
Expand Down Expand Up @@ -909,6 +916,30 @@ def merge_lora_weights(state_dict, name):
pass
pass

def get_original_model_id(local_path: str):
    """Look up the original model id recorded in a local model directory.

    Two files under ``local_path`` are consulted, in this order:

    1. ``config.json`` -- its ``_name_or_path`` entry.
    2. ``adapter_config.json`` -- its ``base_model_name_or_path`` entry.

    The first entry that holds a non-empty string is returned.

    Args:
        local_path: Directory expected to contain the config file(s).

    Returns:
        The recorded model id as a string, or ``None`` when neither file
        exists or neither contains a usable entry.

    Raises:
        json.JSONDecodeError: If a config file exists but is not valid JSON.

    Note:
        The stored value is returned as-is. It is NOT verified to be a
        remote repository id rather than another local path.
    """
    import json
    import os

    # (file name, key holding the model id), in lookup priority order.
    candidates = (
        ("config.json", "_name_or_path"),
        ("adapter_config.json", "base_model_name_or_path"),
    )

    for file_name, key in candidates:
        config_path = os.path.join(local_path, file_name)
        if not os.path.isfile(config_path):
            continue

        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        if not isinstance(config, dict):
            continue

        model_id = config.get(key)
        # An empty or non-string entry cannot identify a model, so fall
        # through to the next candidate file instead of returning it.
        if isinstance(model_id, str) and model_id:
            return model_id

    return None

# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
Expand Down