Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,4 @@ jobs:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the TRL CI with dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
51 changes: 49 additions & 2 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import datasets
import jinja2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -40,7 +41,7 @@
is_apex_available,
is_wandb_available,
)
from transformers.trainer_utils import EvalPrediction, seed_worker
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

Expand Down Expand Up @@ -614,11 +615,57 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
metrics = None
if self.control.should_evaluate:
metrics = self._evaluate(trial, ignore_keys_for_eval)
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

if self.args.save_strategy == "best":
self.control.should_save = is_new_best_metric

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self._save_checkpoint(model, trial)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
# This can be removed once the minimum transformers version is updated to 4.47.
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
def _determine_best_metric(self, metrics, trial):
"""
Determine if the model should be saved based on the evaluation metrics.
If args.metric_for_best_model is not set, the loss is used.
Returns:
bool: True if a new best metric was found, else False
"""
is_new_best_metric = False

if self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model

if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"

try:
metric_value = metrics[metric_to_check]
except KeyError as exc:
raise KeyError(
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
) from exc

operator = np.greater if self.args.greater_is_better else np.less

if self.state.best_metric is None:
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

if operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

is_new_best_metric = True

return is_new_best_metric

def create_model_card(
self,
model_name: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def repeat_generator():
self.lr_scheduler.step()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=metrics)
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
torch.cuda.empty_cache()
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def repeat_generator():
self.lr_scheduler.step()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=metrics)
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
torch.cuda.empty_cache()
gc.collect()
Expand Down