Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
650 commits
Select commit Hold shift + click to select a range
874ccc8
Update compiler.py
danielhanchen Mar 11, 2025
10b18ea
debugging
danielhanchen Mar 11, 2025
c79a57d
remove debugging
danielhanchen Mar 11, 2025
33f9482
num items in batch
danielhanchen Mar 11, 2025
3c49dd4
Update compiler.py
danielhanchen Mar 11, 2025
cf10a61
Update compiler.py
danielhanchen Mar 11, 2025
441771c
Update compiler.py
danielhanchen Mar 11, 2025
4c560a3
Update compiler.py
danielhanchen Mar 11, 2025
d43a386
Update compiler.py
danielhanchen Mar 11, 2025
c3cc10e
Update compiler.py
danielhanchen Mar 11, 2025
b5dbd89
Update compiler.py
danielhanchen Mar 11, 2025
2d9f12a
Update compiler.py
danielhanchen Mar 11, 2025
8330a31
Update compiler.py
danielhanchen Mar 11, 2025
d5623c2
Update compiler.py
danielhanchen Mar 11, 2025
7403751
Update compiler.py
danielhanchen Mar 11, 2025
a87ad7c
Update compiler.py
danielhanchen Mar 11, 2025
4cf6bb4
Update compiler.py
danielhanchen Mar 11, 2025
c66150b
Update compiler.py
danielhanchen Mar 11, 2025
9b98570
Update compiler.py
danielhanchen Mar 11, 2025
b315f39
Update compiler.py
danielhanchen Mar 11, 2025
ae1a2fd
Update compiler.py
danielhanchen Mar 11, 2025
d7b08e8
Update compiler.py
danielhanchen Mar 11, 2025
27e3fd1
Update compiler.py
danielhanchen Mar 11, 2025
c97ffda
Update compiler.py
danielhanchen Mar 11, 2025
b7a84d4
Update compiler.py
danielhanchen Mar 11, 2025
c0b1879
Update compiler.py
danielhanchen Mar 11, 2025
ac41081
Update compiler.py
danielhanchen Mar 11, 2025
ebb3109
Update compiler.py
danielhanchen Mar 11, 2025
de2e580
Update compiler.py
danielhanchen Mar 11, 2025
6e38eec
logs
danielhanchen Mar 11, 2025
c7dfd06
Update patching_utils.py
danielhanchen Mar 11, 2025
f35ada1
VLM attention mask
danielhanchen Mar 12, 2025
8efec06
Update loss_utils.py
danielhanchen Mar 12, 2025
75b2e9e
Update loss_utils.py
danielhanchen Mar 12, 2025
bac14cb
Update loss_utils.py
danielhanchen Mar 12, 2025
d791919
Update loss_utils.py
danielhanchen Mar 12, 2025
b5f9d32
Update loss_utils.py
danielhanchen Mar 12, 2025
4fe56b6
Update loss_utils.py
danielhanchen Mar 12, 2025
d6187fe
Recheck
danielhanchen Mar 12, 2025
c9eeece
Update compiler.py
danielhanchen Mar 12, 2025
59e7860
Update patching_utils.py
danielhanchen Mar 12, 2025
c4945dd
Update patching_utils.py
danielhanchen Mar 12, 2025
d293453
Update patching_utils.py
danielhanchen Mar 12, 2025
5afbb3e
Update patching_utils.py
danielhanchen Mar 12, 2025
529a926
Update compiler.py
danielhanchen Mar 12, 2025
c8f14ce
Update patching_utils.py
danielhanchen Mar 12, 2025
97d8190
suppress errors
danielhanchen Mar 12, 2025
bf36a7e
Update compiler.py
danielhanchen Mar 12, 2025
2f6d5ec
Update patching_utils.py
danielhanchen Mar 12, 2025
1339306
Update compiler.py
danielhanchen Mar 12, 2025
bee764b
Update patching_utils.py
danielhanchen Mar 12, 2025
800077f
Update patching_utils.py
danielhanchen Mar 12, 2025
83ae6be
Update patching_utils.py
danielhanchen Mar 12, 2025
74c40ab
Update peft_utils.py
danielhanchen Mar 12, 2025
d37e823
Update compiler.py
danielhanchen Mar 12, 2025
e4869ff
Update loss_utils.py
danielhanchen Mar 12, 2025
08c4a4f
Update loss_utils.py
danielhanchen Mar 12, 2025
4e8773a
Merge branch 'main' into nightly
danielhanchen Mar 12, 2025
ea79be4
Merge branch 'main' into nightly
danielhanchen Mar 12, 2025
8bda25a
Update compiler.py
danielhanchen Mar 12, 2025
09f9c7e
Update compiler.py
danielhanchen Mar 12, 2025
acf74ec
Update compiler.py
danielhanchen Mar 12, 2025
c21c990
Update compiler.py
danielhanchen Mar 12, 2025
164cdea
Update compiler.py
danielhanchen Mar 12, 2025
d549aa6
Update compiler.py
danielhanchen Mar 12, 2025
cfb6851
Update compiler.py
danielhanchen Mar 12, 2025
3344d4e
bug fixes
danielhanchen Mar 12, 2025
1d45bfa
Update compiler.py
danielhanchen Mar 12, 2025
de6c061
Update compiler.py
danielhanchen Mar 12, 2025
28e9318
Update vision_utils.py
danielhanchen Mar 12, 2025
bf60148
Update compiler.py
danielhanchen Mar 12, 2025
7739c37
Update loss_utils.py
danielhanchen Mar 12, 2025
3192f8d
Update loss_utils.py
danielhanchen Mar 12, 2025
9f6f012
Update loss_utils.py
danielhanchen Mar 12, 2025
e39740c
Update loss_utils.py
danielhanchen Mar 12, 2025
6e17816
Bug fixes
danielhanchen Mar 13, 2025
f2f1a2e
Update dataset_utils.py
danielhanchen Mar 13, 2025
9889307
Update dataset_utils.py
danielhanchen Mar 13, 2025
b56b523
Update dataset_utils.py
danielhanchen Mar 13, 2025
a7c257a
Update dataset_utils.py
danielhanchen Mar 13, 2025
6556f13
Update dataset_utils.py
danielhanchen Mar 13, 2025
9a05b2f
Update compiler.py
danielhanchen Mar 13, 2025
0f6fc7a
Update compiler.py
danielhanchen Mar 13, 2025
4bb152a
Update compiler.py
danielhanchen Mar 13, 2025
3ccdf86
Update compiler.py
danielhanchen Mar 13, 2025
1783ba1
Update loss_utils.py
danielhanchen Mar 13, 2025
b2983f4
Update loss_utils.py
danielhanchen Mar 13, 2025
3c9c731
Merge branch 'main' into nightly
danielhanchen Mar 13, 2025
4c5c77d
gpu_memory_utilization
danielhanchen Mar 14, 2025
b918327
Update temporary_patches.py
danielhanchen Mar 14, 2025
e8f561c
Update vision_utils.py
danielhanchen Mar 14, 2025
4459ef8
Update vision_utils.py
danielhanchen Mar 14, 2025
62e0e14
Update vision_utils.py
danielhanchen Mar 14, 2025
9f4b729
Update vision_utils.py
danielhanchen Mar 14, 2025
29a7abf
Update vision_utils.py
danielhanchen Mar 14, 2025
9830edd
Update vision_utils.py
danielhanchen Mar 14, 2025
be53fda
Update vision_utils.py
danielhanchen Mar 14, 2025
28f4df4
Update vision_utils.py
danielhanchen Mar 14, 2025
ad13d0a
train on completions VLMs
danielhanchen Mar 14, 2025
370cbd7
Update dataset_utils.py
danielhanchen Mar 14, 2025
bd60d26
Update dataset_utils.py
danielhanchen Mar 14, 2025
29ed559
Update dataset_utils.py
danielhanchen Mar 14, 2025
e0a4416
Update dataset_utils.py
danielhanchen Mar 14, 2025
d6e55ca
VLM train only on completions
danielhanchen Mar 14, 2025
adf8307
Update loss_utils.py
danielhanchen Mar 14, 2025
98d5885
Update dataset_utils.py
danielhanchen Mar 14, 2025
967c2ba
Update compiler.py
danielhanchen Mar 14, 2025
ddf2b8e
Update compiler.py
danielhanchen Mar 14, 2025
cb2f6c7
Update compiler.py
danielhanchen Mar 14, 2025
873c514
Update compiler.py
danielhanchen Mar 14, 2025
ca0b499
Update compiler.py
danielhanchen Mar 14, 2025
4908a16
Update compiler.py
danielhanchen Mar 14, 2025
1d4b5d7
Update compiler.py
danielhanchen Mar 14, 2025
81b45c6
Update saving_utils.py
danielhanchen Mar 14, 2025
261ffd2
Update llama_cpp.py
danielhanchen Mar 14, 2025
2ed281a
Update llama_cpp.py
danielhanchen Mar 14, 2025
d89a8fa
Update saving_utils.py
danielhanchen Mar 14, 2025
106736a
Update saving_utils.py
danielhanchen Mar 14, 2025
4abfdcd
Update __init__.py
danielhanchen Mar 14, 2025
0ac4464
Update compiler.py
danielhanchen Mar 14, 2025
e2fbe79
Update loss_utils.py
danielhanchen Mar 14, 2025
82665d4
Update compiler.py
danielhanchen Mar 14, 2025
9b6142e
Update loss_utils.py
danielhanchen Mar 14, 2025
ee92817
Update loss_utils.py
danielhanchen Mar 14, 2025
9b7600d
Update llama_cpp.py
danielhanchen Mar 14, 2025
d5b6d1c
Update loss_utils.py
danielhanchen Mar 14, 2025
86516ad
Update compiler.py
danielhanchen Mar 14, 2025
29553e4
Update llama_cpp.py
danielhanchen Mar 14, 2025
33e6c8e
Update compiler.py
danielhanchen Mar 14, 2025
5202605
Update vllm_utils.py
danielhanchen Mar 14, 2025
ca52896
Update rl_replacements.py
danielhanchen Mar 14, 2025
7baa442
Update rl_replacements.py
danielhanchen Mar 14, 2025
7ff5a1a
Update rl_replacements.py
danielhanchen Mar 14, 2025
e80aa10
Update rl_replacements.py
danielhanchen Mar 14, 2025
e93d93f
Update rl_replacements.py
danielhanchen Mar 14, 2025
9a6c231
Update rl_replacements.py
danielhanchen Mar 14, 2025
c8abd45
Update training_utils.py
danielhanchen Mar 14, 2025
1633c78
Merge branch 'main' into nightly
danielhanchen Mar 15, 2025
964129b
Update dataset_utils.py
danielhanchen Mar 15, 2025
3b690ad
Update dataset_utils.py
danielhanchen Mar 16, 2025
7bb4a13
Revert "Update dataset_utils.py"
danielhanchen Mar 16, 2025
947c5e9
Update temporary_patches.py
danielhanchen Mar 16, 2025
2fe9c6c
Update temporary_patches.py
danielhanchen Mar 16, 2025
0b2dc97
Update temporary_patches.py
danielhanchen Mar 16, 2025
b9a96dc
Update temporary_patches.py
danielhanchen Mar 16, 2025
0784a07
Update temporary_patches.py
danielhanchen Mar 16, 2025
80c2dc8
Update temporary_patches.py
danielhanchen Mar 16, 2025
26c817d
Update compiler.py
danielhanchen Mar 16, 2025
d3cdd17
Update compiler.py
danielhanchen Mar 16, 2025
31e778a
Remove prints
danielhanchen Mar 16, 2025
2c6a3c5
Update compiler.py
danielhanchen Mar 16, 2025
f3f3c9c
Update saving_utils.py
danielhanchen Mar 16, 2025
93b6a88
Update temporary_patches.py
danielhanchen Mar 16, 2025
86aee5c
Update __init__.py
danielhanchen Mar 16, 2025
ac38bff
Update pyproject.toml
danielhanchen Mar 16, 2025
f64e153
Update vllm_utils.py
danielhanchen Mar 16, 2025
4c72e79
bug fix #2008 unsloth issue - load_in_4bit = True + fast_inference = …
void-mckenzie Mar 16, 2025
1974798
Update dataset_utils.py
danielhanchen Mar 16, 2025
4df4417
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 16, 2025
a5c20e1
Update compiler.py
danielhanchen Mar 17, 2025
a434d45
Update temporary_patches.py
danielhanchen Mar 17, 2025
3cfb98f
Gemma 3 fixes
danielhanchen Mar 17, 2025
fc5f1c0
Update temporary_patches.py
danielhanchen Mar 17, 2025
b317e90
Update compiler.py
danielhanchen Mar 17, 2025
4121dd0
Update compiler.py
danielhanchen Mar 17, 2025
c59dcde
Gemma 3 fixes
danielhanchen Mar 17, 2025
d98ae2e
Update patching_utils.py
danielhanchen Mar 17, 2025
3073ea3
Update compiler.py
danielhanchen Mar 17, 2025
57ff5f6
Update compiler.py
danielhanchen Mar 17, 2025
c7e803b
Update patching_utils.py
danielhanchen Mar 17, 2025
3daaf0d
Update temporary_patches.py
danielhanchen Mar 17, 2025
b619b58
Update compiler.py
danielhanchen Mar 17, 2025
4e78082
Update compiler.py
danielhanchen Mar 17, 2025
c8ba677
Update temporary_patches.py
danielhanchen Mar 17, 2025
fb68ecc
Update temporary_patches.py
danielhanchen Mar 17, 2025
e5a73fe
Update temporary_patches.py
danielhanchen Mar 17, 2025
d7bbe30
Update temporary_patches.py
danielhanchen Mar 17, 2025
5f99275
Update temporary_patches.py
danielhanchen Mar 17, 2025
346812f
Update temporary_patches.py
danielhanchen Mar 17, 2025
b907d0c
Update temporary_patches.py
danielhanchen Mar 17, 2025
789171c
Update temporary_patches.py
danielhanchen Mar 17, 2025
4e2c94a
Update compiler.py
danielhanchen Mar 17, 2025
4740c99
Update compiler.py
danielhanchen Mar 17, 2025
4658d94
Update compiler.py
danielhanchen Mar 17, 2025
f9de6e9
Update compiler.py
danielhanchen Mar 17, 2025
dbdbc63
Update compiler.py
danielhanchen Mar 17, 2025
55b1963
Update compiler.py
danielhanchen Mar 17, 2025
e997ee1
Update compiler.py
danielhanchen Mar 17, 2025
0ba033f
Update compiler.py
danielhanchen Mar 17, 2025
bf821ba
Update compiler.py
danielhanchen Mar 17, 2025
d8c6e59
Update compiler.py
danielhanchen Mar 17, 2025
9967ce3
Update compiler.py
danielhanchen Mar 17, 2025
7b0c535
Update compiler.py
danielhanchen Mar 17, 2025
e6859ce
Update compiler.py
danielhanchen Mar 17, 2025
b2a8f47
Update compiler.py
danielhanchen Mar 17, 2025
ca79c93
Update compiler.py
danielhanchen Mar 17, 2025
3f67ed6
Update compiler.py
danielhanchen Mar 17, 2025
e5fb044
Update compiler.py
danielhanchen Mar 18, 2025
4a1bf2f
Update compiler.py
danielhanchen Mar 18, 2025
36ec4ee
Update compiler.py
danielhanchen Mar 18, 2025
7d1dc81
compiler
danielhanchen Mar 18, 2025
16d6137
Update gradient_checkpointing.py
danielhanchen Mar 18, 2025
9b78566
Update temporary_patches.py
danielhanchen Mar 18, 2025
e0edefe
Update temporary_patches.py
danielhanchen Mar 18, 2025
719e379
Update temporary_patches.py
danielhanchen Mar 18, 2025
8beb2b7
Update temporary_patches.py
danielhanchen Mar 18, 2025
f9cf701
Update temporary_patches.py
danielhanchen Mar 18, 2025
aa8848c
Update temporary_patches.py
danielhanchen Mar 18, 2025
ee940a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
d086158
Update temporary_patches.py
danielhanchen Mar 18, 2025
5a43de2
Update temporary_patches.py
danielhanchen Mar 18, 2025
1f6589b
Update temporary_patches.py
danielhanchen Mar 18, 2025
9b904a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
3c0504b
Update temporary_patches.py
danielhanchen Mar 18, 2025
417161e
Update temporary_patches.py
danielhanchen Mar 18, 2025
3f024b6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b0bd2f4
Update temporary_patches.py
danielhanchen Mar 18, 2025
640e071
Update temporary_patches.py
danielhanchen Mar 18, 2025
05c2232
Update temporary_patches.py
danielhanchen Mar 18, 2025
593eecb
Update temporary_patches.py
danielhanchen Mar 18, 2025
e9c935f
Update temporary_patches.py
danielhanchen Mar 18, 2025
b71160c
causal mask dtype
danielhanchen Mar 18, 2025
a6fedb6
Fix checkpoint and save from local file (#74)
Erland366 Mar 18, 2025
c566b02
Update patching_utils.py
danielhanchen Mar 18, 2025
aaf5feb
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 18, 2025
94f5f4f
Update patching_utils.py
danielhanchen Mar 18, 2025
26c67cf
Update temporary_patches.py
danielhanchen Mar 18, 2025
d92bab6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b04cf4b
Update compiler.py
danielhanchen Mar 18, 2025
4565db3
Update loss_utils.py
danielhanchen Mar 18, 2025
e368810
Update compiler.py
danielhanchen Mar 18, 2025
ce07e0f
Update vllm_utils.py
danielhanchen Mar 18, 2025
114150d
Update compiler.py
danielhanchen Mar 18, 2025
6bd69f1
Update peft_utils.py
danielhanchen Mar 18, 2025
9cee216
Update rl_replacements.py
danielhanchen Mar 18, 2025
df8ac03
Update vllm_utils.py
danielhanchen Mar 18, 2025
e5a321f
Update temporary_patches.py
danielhanchen Mar 18, 2025
134857d
Update temporary_patches.py
danielhanchen Mar 18, 2025
dec6433
Update temporary_patches.py
danielhanchen Mar 18, 2025
b14149b
Update temporary_patches.py
danielhanchen Mar 18, 2025
07f7dde
Update temporary_patches.py
danielhanchen Mar 18, 2025
7600d35
Update temporary_patches.py
danielhanchen Mar 18, 2025
679edeb
Update temporary_patches.py
danielhanchen Mar 18, 2025
5fd25ec
Update temporary_patches.py
danielhanchen Mar 18, 2025
a884b3c
Update temporary_patches.py
danielhanchen Mar 18, 2025
b6ab8bd
Update temporary_patches.py
danielhanchen Mar 18, 2025
cc3ca48
Update temporary_patches.py
danielhanchen Mar 18, 2025
9f5b67d
Update temporary_patches.py
danielhanchen Mar 18, 2025
e4980b2
Update temporary_patches.py
danielhanchen Mar 18, 2025
d745fb7
Update temporary_patches.py
danielhanchen Mar 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ classifiers = [
dependencies = [
"torch",
"triton ; platform_system == 'Linux'",
"packaging",
"triton_windows ; platform_system == 'Windows'",
"packaging>=24.1",
"tyro",
"transformers>=4.46.1,!=4.47.0",
"datasets>=2.16.0",
Expand Down
2 changes: 1 addition & 1 deletion unsloth_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__version__ = "2025.3.12"
__version__ = "2025.3.13"

from importlib.util import find_spec
if find_spec("unsloth") is None:
Expand Down
62 changes: 52 additions & 10 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def get_transformers_model_type(
config = str(config.to_dict())
model_types = re.findall(r"'model_type': '([^\s\']{1,})'", config)
model_types = [x.replace("-", "_").lower() for x in model_types]
# Add splitted modules for eg gemma3_text -> gemma3
model_types += [x.split("_")[0] for x in model_types]
model_types = list(dict().fromkeys(model_types))

from transformers import models
models = dir(models)
Expand Down Expand Up @@ -1213,6 +1216,34 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):

"""

COMPILED_LORA_FORWARD_forced_float32 = """
torch_addmm = torch.addmm
torch_add = torch.add
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
xA = dropout(x.to(torch.float16)) @ lora_A.weight.to(torch.float16).t()
# output = result + scaling * xA @ lora_B.weight.t()
shape = result.shape
output = torch_addmm(
result.view(-1, shape[-1]),
xA.view(-1, xA.shape[-1]),
lora_B.weight.to(torch.float16).t(),
alpha = scaling,
beta = 1,
).view(shape)

bias = lora_B.bias
if bias is not None:
output = torch_add(
output,
bias.to(torch.float16),
alpha = scaling,
)
return output
pass

"""

def patch_lora_forwards(torch_compile_options):
# All Unsloth Zoo code licensed under LGPLv3
Linear_LoRA_Layers = get_lora_layer_modules()
Expand Down Expand Up @@ -1254,26 +1285,37 @@ def patch_lora_forwards(torch_compile_options):
)

# Check failed upcasting
if "torch.is_autocast_enabled()" not in source:
source = source.replace(
"x = x.to(lora_A.weight.dtype)",
"if not torch.is_autocast_enabled(): "\
"result, x = "\
"result.to(lora_A.weight.dtype), "\
"x.to(lora_A.weight.dtype)"
)
replacements = [
"x = x.to(lora_A.weight.dtype)",
"x = self._cast_input_dtype(x, lora_A.weight.dtype)",
]
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
if "torch.is_autocast_enabled()" not in source:
new = "if not torch.is_autocast_enabled(): "\
"result, x = "\
"result.to(lora_A.weight.dtype), "\
"x.to(lora_A.weight.dtype)"
for replace in replacements:
source = source.replace(replace, new)
else:
for replace in replacements:
source = source.replace(replace, "")
pass

source = source.replace(
"self._check_forward_args(x, *args, **kwargs)",
"",
)

if hash(source) != old_hash:
success += 1
compiled_lora_forward = \
COMPILED_LORA_FORWARD \
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \
else COMPILED_LORA_FORWARD_forced_float32

forward = create_new_function(
f"{child}_peft_forward",
COMPILED_LORA_FORWARD + source,
compiled_lora_forward + source,
parent,
dir(eval(parent)),
prepend = \
Expand Down
1 change: 1 addition & 0 deletions unsloth_zoo/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def sft_prepare_dataset(
if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
dataset_text_field = getattr(args, "dataset_text_field", "text")
do_truncation = max_seq_length != 0
do_formatting_func = False
Expand Down
2 changes: 1 addition & 1 deletion unsloth_zoo/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def fused_linear_cross_entropy(
reduction = "sum" if num_items_in_batch is not None else "mean"
if logit_softcapping == 0: logit_softcapping = None
loss = linear_cross_entropy(
hidden_states,
hidden_states.to(lm_weight.dtype),
lm_weight,
targets = labels,
ignore_index = ignore_index,
Expand Down
45 changes: 44 additions & 1 deletion unsloth_zoo/patching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,13 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True):
pass


def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embeddings = True):
def patch_model_and_tokenizer(
model,
tokenizer,
downcast_rope = True,
fix_embeddings = True,
do_forced_float32 = False,
):
# All Unsloth Zoo code licensed under LGPLv3
assert(type(downcast_rope) is bool)
import gc
Expand Down Expand Up @@ -221,7 +227,44 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embedd
correct_dtype = _get_dtype(model.config.torch_dtype)
except:
correct_dtype = model.get_input_embeddings().weight.dtype
# If we force float32, we first use bfloat16, then downcast to float16
if do_forced_float32:
correct_dtype = torch.float16
for name, module in model.named_modules():
if "down_proj" in name or "up_proj" in name or "gate_proj" in name:
exec(f"module.to(torch.float16)")
if "q_proj" in name or "k_proj" in name or "v_proj" in name or "o_proj" in name:
exec(f"module.to(torch.float16)")
if "lm_head" in name or "embed_tokens" in name:
exec(f"module.to(torch.float16)")
if "norm" in name:
exec(f"module.to(torch.float32)")
assert(module.weight.dtype == torch.float32)
torch.cuda.empty_cache()
pass

# Correct torch_dtype
def __fix_dtype(config):
if not hasattr(config, "to_dict"): return
dicts = config.to_dict()
for key, value in dicts.items():
if key == "torch_dtype":
setattr(config, "torch_dtype", torch.float16)
else:
__fix_dtype(getattr(config, key))
m = model
while hasattr(m, "model"):
if hasattr(m, "dtype"):
try: setattr(m, "dtype", torch.float16)
except: pass
if hasattr(m, "config"): __fix_dtype(m.config)
m = m.model
pass
if hasattr(m, "config"): __fix_dtype(m.config)
if hasattr(m, "dtype"):
try: setattr(m, "dtype", torch.float16)
except: pass
pass
# Check all params and patch!
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
Expand Down
7 changes: 6 additions & 1 deletion unsloth_zoo/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def requires_grad_pre_hook(module, input):
raise RuntimeError("Unsloth: Failed to make input require gradients!")
# print(f" WARNING: Empty list input to {module.__class__.__name__}!") #
# return
input[0].requires_grad_(True)
if torch.is_floating_point(input[0]):
input[0].requires_grad_(True)
else:
raise RuntimeError("Unsloth: Failed to make input require gradients!")
pass
Expand Down Expand Up @@ -247,6 +248,10 @@ def requires_grad_pre_hook(module, input):
if f"in self.{module_list}:" in forward:
final_where = j
break
elif re.search(r"for [^\s]{3,} in self\." + module_list, forward) is not None:
# Might have failed finding self.layers: like self.layers[...]:
final_where = j
break
pass
pass
pass
Expand Down
6 changes: 3 additions & 3 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages)
n_mask_per_reward = mask.sum(1)

# See https://github.com/huggingface/trl/pull/2881
# loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
# loss = loss_per_reward.mean()
loss = (loss_i * mask).sum() / mask.sum()
loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
loss = loss_per_reward.mean()
# loss = (loss_i * mask).sum() / mask.sum()

# Get metrics as well which are folded
with torch.inference_mode():
Expand Down
33 changes: 32 additions & 1 deletion unsloth_zoo/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
pass
from transformers.modeling_utils import PushToHubMixin
import json
import os
from pathlib import Path
import tempfile
from peft import PeftModelForCausalLM
Expand Down Expand Up @@ -540,7 +541,13 @@ def merge_and_overwrite_lora(
model_name = model.config._name_or_path

# Find repository's max shard size and total size of everything
file_list = HfFileSystem(token = token).ls(model_name, detail = True)
try:
file_list = HfFileSystem(token = token).ls(model_name, detail = True)
except:
original_model_id = get_original_model_id(model_name)
model_name = original_model_id
file_list = HfFileSystem(token = token).ls(model_name, detail = True)

safetensors_list = []
max_size_in_bytes = 0
total_size_in_bytes = 0
Expand Down Expand Up @@ -909,6 +916,30 @@ def merge_lora_weights(state_dict, name):
pass
pass

def get_original_model_id(local_path: str):
import json
import os

config_path = os.path.join(local_path, "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)

# Check for _name_or_path that's not a local path
# When we load using AutoConfig, the _name_or_path changed into the local path instead
if "_name_or_path" in config:
return config["_name_or_path"]

config_path = os.path.join(local_path, "adapter_config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)

if "base_model_name_or_path" in config:
return config["base_model_name_or_path"]

return None

# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
Expand Down
Loading