Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions unsloth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
import trl
import inspect
from trl import SFTTrainer
try:
from trl import SFTConfig as TrainingArguments
except:
from transformers import TrainingArguments
pass
from . import is_bfloat16_supported
from unsloth_zoo.training_utils import unsloth_train as _unsloth_train
from packaging.version import Version
Expand Down Expand Up @@ -60,7 +55,11 @@ def unsloth_train(trainer, *args, **kwargs):
pass
pass


try:
from trl import SFTConfig as TrainingArguments
except:
from transformers import TrainingArguments
pass
@dataclass
class UnslothTrainingArguments(TrainingArguments):
embedding_learning_rate : Optional[float] = field(
Expand Down Expand Up @@ -134,7 +133,7 @@ def create_optimizer(self):

# Starting with `trl>=0.13.0`, several parameters are passed to the trainer differently.
# We patch the trainer classes to make the transition smooth.
def create_backwards_compatible_trainer(trainer_class, config_class):
def _backwards_compatible_trainer(trainer_class, config_class):
original_init = trainer_class.__init__

@wraps(original_init)
Expand Down Expand Up @@ -167,6 +166,7 @@ def new_init(self, *args, **kwargs):
}

# Get the parameters that exist in config_class but not in TrainingArguments
from transformers import TrainingArguments
moved_params = \
set(inspect.signature(config_class) .parameters.keys()) - \
set(inspect.signature(TrainingArguments).parameters.keys())
Expand Down Expand Up @@ -207,14 +207,13 @@ def _patch_trl_trainer():

import trl.trainer
trl_classes = dir(trl.trainer)

non_convertable_trainer = set(["PPOv2", "AlignProp"])
trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer")) - non_convertable_trainer
trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config")) - non_convertable_trainer
trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer"))
trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config"))
trl_classes = list(trl_trainers & trl_configs)

for x in trl_classes:
exec(f"trl.{x}Trainer.__init__ = create_backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
try: exec(f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
except: continue
pass

trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True
Expand Down