From 5854497759b53d27a60ad64eafac1b24c45628a2 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 9 May 2024 14:00:01 -0400 Subject: [PATCH 1/5] visualize rm prediction --- examples/scripts/reward_modeling.py | 21 ++++++++++-------- trl/trainer/reward_trainer.py | 33 ++++++++++++++++++++++++++++- trl/trainer/utils.py | 18 +++++++++++++++- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 5f7c459632f..3acf84b761e 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -15,9 +15,9 @@ python examples/scripts/reward_modeling.py \ --model_name_or_path=facebook/opt-350m \ --output_dir="reward_modeling_anthropic_hh" \ - --per_device_train_batch_size=64 \ + --per_device_train_batch_size=1 \ --num_train_epochs=1 \ - --gradient_accumulation_steps=16 \ + --gradient_accumulation_steps=32 \ --gradient_checkpointing=True \ --learning_rate=1.41e-5 \ --report_to="wandb" \ @@ -42,8 +42,8 @@ if __name__ == "__main__": parser = HfArgumentParser((RewardConfig, ModelConfig)) - reward_config, model_config = parser.parse_args_into_dataclasses() - reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False) + config, model_config = parser.parse_args_into_dataclasses() + config.gradient_checkpointing_kwargs = dict(use_reentrant=False) ################ # Model & Tokenizer @@ -103,8 +103,7 @@ def preprocess_function(examples): num_proc=4, ) raw_datasets = raw_datasets.filter( - lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length - and len(x["input_ids_rejected"]) <= reward_config.max_length + lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length ) train_dataset = raw_datasets["train"] eval_dataset = raw_datasets["test"] @@ -115,10 +114,14 @@ def preprocess_function(examples): trainer = RewardTrainer( model=model, tokenizer=tokenizer, - args=reward_config, + args=config, train_dataset=train_dataset, eval_dataset=eval_dataset, peft_config=get_peft_config(model_config), ) - trainer.train() - trainer.save_model(reward_config.output_dir) + # trainer.train() + # trainer.save_model(config.output_dir) + # trainer.push_to_hub() + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + print(metrics) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index bbee5e705e0..3afbff94784 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -13,11 +13,14 @@ # limitations under the License. import inspect import warnings +from collections import defaultdict from dataclasses import FrozenInstanceError, replace from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import pandas as pd import torch import torch.nn as nn +from accelerate.utils import gather_object from datasets import Dataset from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments from transformers.trainer_callback import TrainerCallback @@ -26,7 +29,7 @@ from ..import_utils import is_peft_available from .reward_config import RewardConfig -from .utils import RewardDataCollatorWithPadding, compute_accuracy +from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table if is_peft_available(): @@ -279,3 +282,31 @@ def prediction_step( labels = self._prepare_inputs(labels) return loss, logits, labels + + def evaluate(self, *args, **kwargs): + self.visualize_samples(sampling=True) + return super().evaluate(*args, **kwargs) + + def visualize_samples(self, sampling: bool = True, num_print_samples: int = 5): + eval_dataloader = self.get_eval_dataloader() + table = defaultdict(list) + for _, inputs in enumerate(eval_dataloader): + loss, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) + chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True) + rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True) + table["chosen_text"].extend(gather_object(chosen_text)) + table["rejected_text"].extend(gather_object(rejected_text)) + table["logits"].extend( + gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) + ) + if len(table["chosen_text"]) >= num_print_samples: + break + df = pd.DataFrame(table) + print_rich_table(pd.DataFrame(table)) + if self.accelerator.process_index == 0: + print_rich_table(df[:num_print_samples]) + if "wandb" in self.args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 31e11f84a40..9cb4d26b95b 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -18,15 +18,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import pandas as pd import torch from accelerate import PartialState from rich.console import Console, Group from rich.live import Live from rich.panel import Panel from rich.progress import Progress +from rich.table import Table from torch.nn.utils.rnn import pad_sequence from torch.utils.data import IterableDataset -from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase +from transformers import ( + BitsAndBytesConfig, + DataCollatorForLanguageModeling, + PreTrainedTokenizerBase, +) from transformers.trainer import TrainerCallback from transformers.trainer_utils import has_length @@ -815,3 +821,13 @@ def on_train_end(self, args, state, control, **kwargs): self.rich_console = None self.training_status = None self.current_step = None + + +def print_rich_table(df: pd.DataFrame) -> Table: + console = Console() + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.print(table) From aec12ccf4065804fcefd00b30c9408262c998ecf Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 9 May 2024 14:05:01 -0400 Subject: [PATCH 2/5] quick update --- trl/trainer/reward_trainer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 3afbff94784..8c80c5a9903 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -287,11 +287,18 @@ def evaluate(self, *args, **kwargs): self.visualize_samples(sampling=True) return super().evaluate(*args, **kwargs) - def visualize_samples(self, sampling: bool = True, num_print_samples: int = 5): + def visualize_samples(self, num_print_samples: int = 4): + """ + Visualize the reward model logits prediction + + Args: + num_print_samples (`int`, defaults to `4`): + The number of samples to print. Set to `-1` to print all samples. + """ eval_dataloader = self.get_eval_dataloader() table = defaultdict(list) for _, inputs in enumerate(eval_dataloader): - loss, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) + _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True) rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True) table["chosen_text"].extend(gather_object(chosen_text)) @@ -299,7 +306,7 @@ def visualize_samples(self, sampling: bool = True, num_print_samples: int = 5): table["logits"].extend( gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) ) - if len(table["chosen_text"]) >= num_print_samples: + if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples: break df = pd.DataFrame(table) print_rich_table(pd.DataFrame(table)) From 3c15cc5ebac51044e6d8ee7d3f5c59e336dd0b28 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 9 May 2024 14:05:43 -0400 Subject: [PATCH 3/5] quick check --- examples/scripts/reward_modeling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 3acf84b761e..c18b7b115fe 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -25,6 +25,7 @@ --optim="adamw_torch" \ --logging_steps=10 \ --evaluation_strategy="steps" \ + --eval_steps=100 \ --max_length=512 \ """ import warnings @@ -119,9 +120,9 @@ def preprocess_function(examples): eval_dataset=eval_dataset, peft_config=get_peft_config(model_config), ) - # trainer.train() - # trainer.save_model(config.output_dir) - # trainer.push_to_hub() + trainer.train() + trainer.save_model(config.output_dir) + trainer.push_to_hub() metrics = trainer.evaluate() trainer.log_metrics("eval", metrics) print(metrics) From 6132a8c79a661a2f2ea7c279236b8915f23562c8 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 9 May 2024 19:03:08 +0000 Subject: [PATCH 4/5] quick fix --- examples/scripts/reward_modeling.py | 6 +++--- trl/trainer/reward_trainer.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index c18b7b115fe..9960e5f73d2 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -15,9 +15,9 @@ python examples/scripts/reward_modeling.py \ --model_name_or_path=facebook/opt-350m \ --output_dir="reward_modeling_anthropic_hh" \ - --per_device_train_batch_size=1 \ + --per_device_train_batch_size=16 \ --num_train_epochs=1 \ - --gradient_accumulation_steps=32 \ + --gradient_accumulation_steps=2 \ --gradient_checkpointing=True \ --learning_rate=1.41e-5 \ --report_to="wandb" \ @@ -25,7 +25,7 @@ --optim="adamw_torch" \ --logging_steps=10 \ --evaluation_strategy="steps" \ - --eval_steps=100 \ + --eval_steps=50 \ --max_length=512 \ """ import warnings diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 8c80c5a9903..3b98d68d900 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -284,10 +284,11 @@ def prediction_step( return loss, logits, labels def evaluate(self, *args, **kwargs): - self.visualize_samples(sampling=True) + num_print_samples = kwargs.pop("num_print_samples", 4) + self.visualize_samples(num_print_samples) return super().evaluate(*args, **kwargs) - def visualize_samples(self, num_print_samples: int = 4): + def visualize_samples(self, num_print_samples: int): """ Visualize the reward model logits prediction From 1095f7b0cf31a221efc93fffd82bf168221a25fc Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 9 May 2024 20:09:11 +0000 Subject: [PATCH 5/5] update eval steps --- examples/scripts/reward_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 9960e5f73d2..5a038d210d0 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -25,7 +25,7 @@ --optim="adamw_torch" \ --logging_steps=10 \ --evaluation_strategy="steps" \ - --eval_steps=50 \ + --eval_steps=500 \ --max_length=512 \ """ import warnings