diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6ba5bcab427..ffa634ad15c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -112,4 +112,4 @@ jobs: slack_channel: ${{ env.CI_SLACK_CHANNEL }} title: 🤗 Results of the TRL CI with dev dependencies status: ${{ job.status }} - slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} \ No newline at end of file + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 2575c2f8865..790e5463872 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -20,6 +20,7 @@ import datasets import jinja2 +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -40,7 +41,7 @@ is_apex_available, is_wandb_available, ) -from transformers.trainer_utils import EvalPrediction, seed_worker +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker from transformers.training_args import OptimizerNames from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging @@ -614,11 +615,57 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno metrics = None if self.control.should_evaluate: metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == "best": + self.control.should_save = is_new_best_metric if self.control.should_save: - self._save_checkpoint(model, trial, metrics=metrics) + self._save_checkpoint(model, trial) self.control = self.callback_handler.on_save(self.args, self.state, self.control) + # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions. + # This can be removed once the minimum transformers version is updated to 4.47. + # Refer to https://github.com/huggingface/trl/pull/2288 for more details. + def _determine_best_metric(self, metrics, trial): + """ + Determine if the model should be saved based on the evaluation metrics. + If args.metric_for_best_model is not set, the loss is used. + Returns: + bool: True if a new best metric was found, else False + """ + is_new_best_metric = False + + if self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + + try: + metric_value = metrics[metric_to_check] + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + operator = np.greater if self.args.greater_is_better else np.less + + if self.state.best_metric is None: + self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") + + if operator(metric_value, self.state.best_metric): + run_dir = self._get_output_dir(trial=trial) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + output_dir = os.path.join(run_dir, checkpoint_folder) + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + is_new_best_metric = True + + return is_new_best_metric + def create_model_card( self, model_name: Optional[str] = None, diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index e491b0622ad..b36be8ffff2 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -566,7 +566,7 @@ def repeat_generator(): self.lr_scheduler.step() self.control = self.callback_handler.on_step_end(args, self.state, self.control) if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=metrics) + self._save_checkpoint(model, trial=None) self.control = self.callback_handler.on_save(self.args, self.state, self.control) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward torch.cuda.empty_cache() diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 18066976cad..941a90e0a7d 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -482,7 +482,7 @@ def repeat_generator(): self.lr_scheduler.step() self.control = self.callback_handler.on_step_end(args, self.state, self.control) if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=metrics) + self._save_checkpoint(model, trial=None) self.control = self.callback_handler.on_save(self.args, self.state, self.control) torch.cuda.empty_cache() gc.collect()