3 changes: 3 additions & 0 deletions unsloth/__init__.py
@@ -172,3 +172,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 from .chat_templates import *
 from .tokenizer_utils import *
 from .trainer import *
+
+# patch sft trainer
+_patch_trl_trainer()
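
With this change, importing unsloth applies the TRL compatibility patch as a module-level side effect. A minimal sketch of the resulting import-time behavior, assuming `trl>=0.13.0` is installed (the branch where the patch is active):

import unsloth   # module import runs _patch_trl_trainer()
import trl

# @wraps(original_init) keeps a handle to the unpatched __init__:
print(trl.SFTTrainer.__init__.__wrapped__)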
112 changes: 102 additions & 10 deletions unsloth/trainer.py
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from dataclasses import dataclass, field
 from typing import Optional
+from functools import wraps
 
+import trl
+import inspect
 from trl import SFTTrainer
 try:
     from trl import SFTConfig as TrainingArguments
Expand All @@ -24,30 +28,38 @@
from . import is_bfloat16_supported
from unsloth_zoo.training_utils import unsloth_train as _unsloth_train
from packaging.version import Version
import dataclasses

__all__ = [
"UnslothTrainingArguments",
"UnslothTrainer",
"unsloth_train",
"_patch_trl_trainer",
]

# Unsloth gradient accumulation fix:
from transformers import __version__ as transformers_version
if Version(transformers_version) > Version("4.45.2"):
def unsloth_train(trainer):
return trainer.train()
def unsloth_train(trainer, *args, **kwargs):
return trainer.train(*args, **kwargs)
pass
else:
def unsloth_train(trainer):
def unsloth_train(trainer, *args, **kwargs):
if len(args) != 0 or len(kwargs) != 0:
raise RuntimeError(
"Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"\
"If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
'`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
)
print(
"Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
"If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
'`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
'`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
)
return _unsloth_train(trainer)
pass
pass

__all__ = [
"UnslothTrainingArguments",
"UnslothTrainer",
"unsloth_train",
]


@dataclass
class UnslothTrainingArguments(TrainingArguments):
@@ -119,3 +131,83 @@ def create_optimizer(self):
         return self.optimizer
     pass
 pass
+
+# From `trl>=0.13.0`, they changed how to pass several params to the trainer
+# We need to patch to make the transition smooth
+def create_backwards_compatible_trainer(trainer_class, config_class):
+    original_init = trainer_class.__init__
+
+    @wraps(original_init)
+    def new_init(self, *args, **kwargs):
+        # All Trainer tokenizer is now called processing_class
+        if "tokenizer" in kwargs:
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
+        if "args" in kwargs:
+            training_args = kwargs.pop("args", None)
+
+            # Get parameters that Trainer.__init__ actually expects
+            trainer_params = set(inspect.signature(original_init).parameters.keys())
+            trainer_params.remove('self')
+            trainer_params.remove('args')
+
+            # Get fields that should be passed to Config init
+            config_fields = {
+                field.name: field for field in dataclasses.fields(config_class)
+                if field.init
+            }
+
+            # Create config dict with valid fields from training_args
+            config_dict = {
+                name: getattr(training_args, name)
+                for name in config_fields
+                if hasattr(training_args, name)
+            }
+
+            # Get parameters that exist in Config but not in TrainingArguments
+            moved_params = \
+                set(inspect.signature(config_class)     .parameters.keys()) - \
+                set(inspect.signature(TrainingArguments).parameters.keys())
+
+            # Separate kwargs into trainer kwargs and config kwargs
+            trainer_kwargs = {}
+            additional_config_kwargs = {}
+
+            for key, value in kwargs.items():
+                if key in trainer_params: trainer_kwargs[key] = value
+                elif key in moved_params or key in config_fields:
+                    additional_config_kwargs[key] = value
+                else:
+                    additional_config_kwargs[key] = value
+                pass
+
+            # Update config_dict with additional kwargs
+            config_dict.update(additional_config_kwargs)
+
+            # Create Config with all the collected parameters
+            config = config_class(**config_dict)
+
+            # Reconstruct kwargs for Trainer
+            kwargs = trainer_kwargs
+            kwargs["args"] = config
+        pass
+        original_init(self, *args, **kwargs)
+    pass
+    return new_init
+
+if Version(trl.__version__) >= Version("0.13.0.dev0"):
+    # print("Patching TRL Trainer to maintain backward compatibility with the old syntax.")
+    def _patch_trl_trainer():
+        import trl.trainer
+        trl_classes = dir(trl.trainer)
+
+        non_convertable_trainer = set(["PPOv2", "AlignProp"])
+        trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer")) - non_convertable_trainer
+        trl_configs  = set(x[:-len("Config")]  for x in trl_classes if x.endswith("Config"))  - non_convertable_trainer
+        trl_classes = list(trl_trainers & trl_configs)
+        for x in trl_classes:
+            exec(f"trl.{x}Trainer.__init__ = create_backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
+        pass
+else:
+    def _patch_trl_trainer(): return
+pass
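
For context, a hedged usage sketch of what the patched `__init__` enables; `model`, `tokenizer`, and `dataset` below are hypothetical placeholders, not part of the diff. Call sites written for `trl<0.13` keep working: `tokenizer=` is renamed to `processing_class`, and keyword arguments that now live on the config class (such as `max_seq_length` for `SFTConfig`) are folded into a freshly built config along with the fields copied from the passed `args`.

import unsloth                      # applies the patch on import
from transformers import TrainingArguments
from trl import SFTTrainer

# Old-style (pre-0.13) call, still valid after the patch:
trainer = SFTTrainer(
    model          = model,         # hypothetical model object
    tokenizer      = tokenizer,     # remapped to processing_class by new_init
    train_dataset  = dataset,       # hypothetical dataset
    args           = TrainingArguments(output_dir = "outputs"),
    max_seq_length = 2048,          # not a Trainer parameter: moved into SFTConfig
)

The `unsloth_train` change is orthogonal: on `transformers` newer than 4.45.2 it simply forwards `*args, **kwargs` to `trainer.train()`, while on older versions it rejects extra arguments and falls back to the custom gradient accumulation fixed trainer.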