Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"triton",
"packaging",
"tyro",
"transformers>=4.44.2",
"transformers>=4.46.1",
"datasets>=2.16.0",
"sentencepiece>=0.2.0",
"tqdm",
Expand Down
14 changes: 11 additions & 3 deletions unsloth_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@

__version__ = "2024.11.1"

import importlib.util
if importlib.util.find_spec("unsloth") is None:
from importlib.util import find_spec
if find_spec("unsloth") is None:
raise ImportError("Please install Unsloth via `pip install unsloth`!")
pass
del importlib.util
del find_spec

import os
if not ("UNSLOTH_IS_PRESENT" in os.environ):
raise ImportError("Please install Unsloth via `pip install unsloth`!")
pass

try:
print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
except:
print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
pass
# Log Unsloth-Zoo Utilities
os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1"
del os
67 changes: 65 additions & 2 deletions unsloth_zoo/gradient_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
"prepare_n_gradient_checkpoints",
"Unsloth_Offloaded_Gradient_Checkpointer",
"unsloth_offloaded_gradient_checkpoint",
"patch_unsloth_gradient_checkpointing",
"unpatch_unsloth_gradient_checkpointing",

"Unsloth_Gradient_Checkpointer",
"unsloth_gradient_checkpoint",
"patch_gradient_checkpointing",
"unpatch_gradient_checkpointing",
]
Expand Down Expand Up @@ -155,6 +160,35 @@ def backward(ctx, dY):
pass


class Unsloth_Gradient_Checkpointer(torch.autograd.Function):
"""
Same as normal gradient checkpointing but cleaner
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
pass

@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (None, hidden_states.grad,) + (None,)*len(ctx.args)
pass
pass


def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)
pass
Expand All @@ -166,15 +200,44 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None,
pass


def patch_gradient_checkpointing():
print("Unsloth: Patching Gradient Checkpointing with Unsloth's special version!")
def unsloth_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
return Unsloth_Gradient_Checkpointer.apply(function, *args)
pass
if (Version(torch.__version__) < Version("2.4.0")) and \
not hasattr(unsloth_gradient_checkpoint, "__wrapped__"):
unsloth_gradient_checkpoint = torch._disable_dynamo(
unsloth_gradient_checkpoint
)
pass


def patch_unsloth_gradient_checkpointing():
print("Unsloth: Patched gradient checkpointing for long context finetuning.")
import torch.utils
if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_offloaded_gradient_checkpoint": return
torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint = unsloth_offloaded_gradient_checkpoint
pass


def patch_gradient_checkpointing():
print("Unsloth: Patched gradient checkpointing.")
import torch.utils
if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_gradient_checkpoint": return
torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint = unsloth_gradient_checkpoint
pass


def unpatch_unsloth_gradient_checkpointing():
import torch.utils
if hasattr(torch.utils.checkpoint, "_old_checkpoint"):
torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint
del torch.utils.checkpoint._old_checkpoint
pass
pass


def unpatch_gradient_checkpointing():
import torch.utils
if hasattr(torch.utils.checkpoint, "_old_checkpoint"):
Expand Down
79 changes: 45 additions & 34 deletions unsloth_zoo/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,63 +16,74 @@

import torch
from packaging.version import Version
torch_nn_functional_cross_entropy = torch.nn.functional.cross_entropy

__all__ = [
"causal_loss_function",
"transformers_losses_patcher",
"patch_loss_function",
"patch_loss_functions",
"post_patch_loss_function",
]


def causal_loss_function(_fast_cross_entropy_loss):
def patch_loss_functions(_fast_cross_entropy_loss):
try:
import transformers.loss.loss_utils
except:
print("Unsloth: Cannot patch loss functions - update transformers for faster modules!")
return None
pass

# Generic cross entropy loss
def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
if ignore_index == -100:
loss = _fast_cross_entropy_loss(
logits = source,
labels = target,
n_items = num_items_in_batch,
)
else:
reduction = "sum" if num_items_in_batch is not None else "mean"
loss = torch_nn_functional_cross_entropy(
source,
target,
ignore_index = ignore_index,
reduction = reduction,
)
if reduction == "sum": loss = loss / num_items_in_batch
return loss
pass

# Causal LM loss
def UnslothForCausalLMLoss(
logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
shift_logits = logits
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
loss = _fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
n_items = num_items_in_batch,
)
shift_labels[..., -1] = ignore_index
loss = unsloth_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
return loss
pass

if (Version(torch.__version__) < Version("2.4.0")):
UnslothForCausalLMLoss = torch._disable_dynamo(UnslothForCausalLMLoss)
pass
return UnslothForCausalLMLoss
pass


def transformers_losses_patcher(UnslothForCausalLMLoss):
def _patch_transformers_losses():
import re
try:
import transformers.loss.loss_utils
except:
print("Unsloth: Cannot patch loss functions - update transformers for faster modules!")
return
pass
# Now patch the losses!
import transformers.modeling_utils
LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING
LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss

import transformers.modeling_utils
LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING
LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss

# Remove @property and @lru_cache
if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"):
transformers.modeling_utils.PreTrainedModel.loss_function = \
transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__
pass
print("Unsloth: Patched cross entropy losses.")
# Remove @property and @lru_cache
if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget") and \
hasattr(transformers.modeling_utils.PreTrainedModel.loss_function.fget, "__wrapped__"):
transformers.modeling_utils.PreTrainedModel.loss_function = \
transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__
pass
return _patch_transformers_losses
print("Unsloth: Patched cross entropy losses.")
pass


def patch_loss_function(model):
def post_patch_loss_function(model):
try:
# model.loss_function starts as a dict to a loss fx
# We invoke it to save it
Expand Down
Loading