Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion unsloth_zoo/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,13 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
if name.endswith(".base_layer.weight"):
name = name[:-len(".base_layer.weight")]

# modules_to_save wraps embed_tokens / lm_head; strip the wrapper
# so the key matches lora_weights entries created by the branch above.
# Only the .weight variant is stripped: the `if name in lora_weights` check
# just below emits both .weight and .bias from the module, so no separate
# bias entry is needed.
elif name.endswith(".modules_to_save.default.weight"):
name = name[:-len(".modules_to_save.default.weight")]

if name in lora_weights:
state_dict[name + ".weight"] = lora_weights[name]
if getattr(lora_weights[name].module, "bias", None) is not None:
Expand Down Expand Up @@ -1479,7 +1486,19 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp

def get_torch_storage_size_new(x, element_size):
    """Return the number of bytes ``x`` occupies once written out.

    For a ``LoraStats`` entry the size is the element count of the merged
    weight multiplied by ``element_size``. The shape is resolved in this
    order:

    1. the weight returned by ``_get_modules_to_save_weight`` for the module
       (presumably the ``modules_to_save`` copy, or ``None`` when the module
       is not such a wrapper -- confirm against the helper),
    2. the module's own ``weight`` attribute,
    3. when there is no module at all, ``(lora_B.shape[0], lora_A.shape[1])``,
    4. ``(in_features, out_features)`` of a Linear-like module.

    Anything that is not a ``LoraStats`` is delegated to
    ``get_torch_storage_size``.

    Args:
        x: A ``LoraStats`` entry or an object accepted by
            ``get_torch_storage_size``.
        element_size: Size in bytes of one element of the stored dtype.

    Returns:
        The storage size in bytes.

    Raises:
        AttributeError: If ``x`` is a ``LoraStats`` whose shape cannot be
            resolved by any of the steps above.
    """
    if not isinstance(x, LoraStats):
        return get_torch_storage_size(x)

    # NOTE: do not read mod.in_features / mod.out_features before the checks
    # below. The module may be None (MoE LoRA wrappers) or may not be
    # Linear-like at all (e.g. a modules_to_save wrapper), and touching
    # those attributes eagerly raises before the dedicated branches run.
    mod = x.module

    # modules_to_save: use the saved weight shape directly
    saved_w = _get_modules_to_save_weight(mod)
    if saved_w is None and hasattr(mod, "weight"):
        saved_w = mod.weight
    if saved_w is not None and hasattr(saved_w, "shape"):
        return int(np.prod(saved_w.shape)) * element_size

    # MoE LoRA wrappers with no .base_layer: infer merged shape from lora matrices
    if mod is None and x.lora_A is not None and x.lora_B is not None:
        shape = (x.lora_B.shape[0], x.lora_A.shape[1])
        return int(np.prod(shape)) * element_size

    # Fallback for Linear-like modules
    shape = (mod.in_features, mod.out_features)
    return int(np.prod(shape)) * element_size
Expand Down