From 94353cb9b3f94325cbf21e7a770fd0881c070c11 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Wed, 3 Jul 2024 03:01:05 +0000
Subject: [PATCH 01/92] online dpo trainer based on rloo trainer

---
 examples/scripts/online_dpo.py    | 103 +++++
 online_dpo_config.py              |  37 ++
 trl/trainer/online_dpo_trainer.py | 619 ++++++++++++++++++++++++++++++
 3 files changed, 759 insertions(+)
 create mode 100644 examples/scripts/online_dpo.py
 create mode 100644 online_dpo_config.py
 create mode 100644 trl/trainer/online_dpo_trainer.py

diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py
new file mode 100644
index 00000000000..b0e73e2703d
--- /dev/null
+++ b/examples/scripts/online_dpo.py
@@ -0,0 +1,103 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from datasets import load_dataset
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+
+from trl import ModelConfig
+from trl.commands.cli_utils import TrlParser
+from trl.trainer.online_dpo_config import OnlineDPOConfig
+from trl.trainer.online_dpo_trainer import OnlineDPOTrainer
+
+
+@dataclass
+class ScriptArguments:
+    dataset_name: str = None
+    dataset_text_field: str = "prompt"
+    dataset_train_split: str = "train"
+    dataset_test_split: Optional[str] = "test"
+    max_length: int = 512
+
+
+def prepare_dataset(dataset, tokenizer, dataset_text_field):
+    """pre-tokenize the dataset before training; only collate during training"""
+
+    def tokenize(element):
+        outputs = tokenizer(
+            element[dataset_text_field],
+            padding=False,
+        )
+        return {"input_ids": outputs["input_ids"]}
+
+    return dataset.map(
+        tokenize,
+        remove_columns=dataset.column_names,
+        batched=True,
+        num_proc=4,  # multiprocessing.cpu_count(),
+        load_from_cache_file=False,
+    )
+
+
+if __name__ == "__main__":
+    parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
+    args, config, model_config = parser.parse_args_and_config()
+
+    ################
+    # Model & Tokenizer
+    ################
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_config.model_name_or_path,
+        padding_side="left",
+        trust_remote_code=True,
+    )
+
+    reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1)
+    ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path)
+    policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path)
+    ################
+    # Dataset
+    ################
+    raw_datasets = load_dataset(args.dataset_name)
+    if config.sanity_check:
+        for key in raw_datasets:
+            raw_datasets[key] = raw_datasets[key].select(range(1024))
+        config.push_to_hub = False
+        config.report_to = ""
+        config.save_strategy = "no"
+        config.num_sample_generations = 0
+        config.total_episodes = 32
+        config.per_device_train_batch_size = 8
+        config.gradient_accumulation_steps = 1
+
+    train_dataset = raw_datasets[args.dataset_train_split]
+    train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field)
+
+    if args.dataset_test_split is not None:
+        eval_dataset = raw_datasets[args.dataset_test_split]
+        eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field)
+    else:
+        eval_dataset = None
+    ################
+    # Training
+    ################
+
+    trainer = OnlineDPOTrainer(
+        config=config,
+        tokenizer=tokenizer,
+        policy=policy,
+        ref_policy=ref_policy,
+        reward_model=reward_model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+    )
+    trainer.train()
+
+    if not config.sanity_check:
+        trainer.save_model(config.output_dir)
+        if config.push_to_hub:
+            trainer.push_to_hub()
+    trainer.generate_completions()
diff --git a/online_dpo_config.py b/online_dpo_config.py
new file mode 100644
index 00000000000..54ca5d1d505
--- /dev/null
+++ b/online_dpo_config.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import Dict, Literal, Optional
+
+from trl.trainer.rloo_config import RLOOConfig
+
+
+@dataclass
+class OnlineDPOConfig(RLOOConfig):
+    save_generations: bool = False
+
+    # DPO stuff w/o max_length which is included in RLOOConfig
+    beta: float = 0.1
+    label_smoothing: float = 0
+    loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo_hard", "nca_pair", "robust"] = (
+        "sigmoid"
+    )
+    label_pad_token_id: int = -100
+    padding_value: int = 0
+    truncation_mode: str = "keep_end"
+    # max_length: Optional[int] = None
+    max_prompt_length: Optional[int] = None
+    max_target_length: Optional[int] = None
+    is_encoder_decoder: Optional[bool] = None
+    disable_dropout: bool = True
+    generate_during_eval: bool = False
+    precompute_ref_log_probs: bool = False
+    dataset_num_proc: Optional[int] = None
+    model_init_kwargs: Optional[Dict] = None
+    ref_model_init_kwargs: Optional[Dict] = None
+    model_adapter_name: Optional[str] = None
+    ref_adapter_name: Optional[str] = None
+    reference_free: bool = False
+    force_use_ref_model: bool = False
+    sync_ref_model: bool = False
+    ref_model_mixup_alpha: float = 0.9
+    ref_model_sync_steps: int = 64
+    rpo_alpha: Optional[float] = None
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
new file mode 100644
index 00000000000..da31f0eedf7
--- /dev/null
+++ b/trl/trainer/online_dpo_trainer.py
@@ -0,0 +1,619 @@
+import gc
+import math
+import os
+import time
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.utils import gather_object
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers import (
+    DataCollatorWithPadding,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+)
+from transformers.integrations import get_reporting_integration_callbacks
+from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
+from transformers.trainer_callback import CallbackHandler, PrinterCallback
+
+from trl.models.utils import unwrap_model_for_generation
+from trl.trainer.rloo_trainer import INVALID_LOGPROB, RLOOTrainer
+from trl.trainer.utils import (
+    disable_dropout_in_model,
+    exact_div,
+    first_true_indices,
+    forward,
+    generate,
+    get_reward,
+    prepare_deepspeed,
+    print_rich_table,
+    truncate_response,
+)
+
+from .online_dpo_config import OnlineDPOConfig
+
+
+@dataclass
+class OnlineTrainerState(TrainerState):
+    episode: int = 0
+
+
+class OnlineDPOTrainer(RLOOTrainer):
+    def __init__(
+        self,
+        config: OnlineDPOConfig,
+        tokenizer: PreTrainedTokenizer,
+        policy: nn.Module,
+        ref_policy: nn.Module,
+        reward_model: nn.Module,
+        train_dataset: Dataset,
+        loss_type: str = "sigmoid",
+        data_collator: Optional[DataCollatorWithPadding] = None,
+        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        # less commonly used
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
+            None,
+            None,
+        ),
+        # compute_metrics:
Optional[Callable[[EvalPrediction], Dict]] = None, + # model_init: Optional[Callable[[torch.nn.Module], None]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + ) -> None: + self.args = config + args = config + self.tokenizer = tokenizer + self.policy = policy + + self.policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.ref_policy = ref_policy + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = args.num_train_epochs * self.train_dataset_len + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches + ) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, + args.num_mini_batches, + "`batch_size` must be a multiple of `num_mini_batches`", + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, + args.num_mini_batches, + "`local_batch_size` must be a multiple of `num_mini_batches`", + ) + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_updates = args.total_episodes // args.batch_size + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_updates // args.num_sample_generations) + + self.local_dataloader_batch_size = args.local_batch_size + + ### DPO stuff + self.beta = config.beta + self.loss_type = config.loss_type + + ######### + # setup model, optimizer, and others + ######### + for module in [policy, ref_policy, reward_model]: + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + self.model = policy + self.create_optimizer_and_scheduler(num_training_steps=args.num_updates) + + ######### + ### trainer specifics + ######### + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, + self.model, + self.tokenizer, + self.optimizer, + self.lr_scheduler, + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, 
"deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + self.backup_model = None + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=DataCollatorWithPadding(tokenizer), + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=DataCollatorWithPadding(self.tokenizer), + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, config.fp16, config.bf16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, config.fp16, config.bf16 + ) + self.deepspeed = self.model + else: + self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + self.ref_model = self.ref_policy + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + self.model_wrapped = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + tokenizer = self.tokenizer + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + min_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + self.state.global_step = 0 + self.state.episode = 0 + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + loss_stats = torch.zeros(stats_shape, device=device) + + # approxkl_stats = torch.zeros(stats_shape, device=device) + # pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + # vf_loss_stats = torch.zeros(stats_shape, device=device) + # vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + # entropy_stats = torch.zeros(stats_shape, device=device) + # ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + self.state.max_steps = args.num_updates * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = 
math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + saved_data = {"prompt": [], "chosen": [], "rejected": [], "update": []} + + for update in range(1, args.num_updates + 1): + self.state.episode += 1 * args.batch_size + self.lr_scheduler.step() + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + query_responses = [] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response, logits = generate( + unwrapped_model, + query, + tokenizer.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + logits /= args.temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + + query_responses.append(query_response) + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + query_responses = torch.cat(query_responses, 0) + responses = torch.cat(responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + del (logprob, ref_logprob, score) + torch.cuda.empty_cache() + gc.collect() + + # Response Processing 3. filter response. 
Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1) + if args.non_eos_penalty: + scores = torch.where(contain_eos_token, scores, torch.full_like(scores, args.penalty_reward_value)) + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + + kl = logprobs - ref_logprobs + non_score_reward = (-args.kl_coef * kl).sum(1) + rlhf_reward = scores + non_score_reward + + # num_examples should be same as args.local_batch_size + num_examples = scores.size(0) // 2 + first_half = scores[:num_examples] + second_half = scores[num_examples:] + + num_examples_range = torch.arange(num_examples).to(scores.device) + + chosen_indices = torch.where( + first_half >= second_half, num_examples_range.clone(), num_examples_range.clone() + num_examples + ) + rejected_indices = torch.where( + first_half < second_half, num_examples_range.clone(), num_examples_range.clone() + num_examples + ) + + scores_margin = scores[chosen_indices] - scores[rejected_indices] + + if self.args.save_generations: + decoded_queries = tokenizer.batch_decode(queries[:num_examples], skip_special_tokens=True) + decoded_chosen = tokenizer.batch_decode(postprocessed_responses[chosen_indices]) + decoded_rejected = tokenizer.batch_decode(postprocessed_responses[rejected_indices]) + + # WARNING, if pad token == eos token, this will remove the eos from the end + decoded_chosen = [r.split(tokenizer.pad_token)[0] for r in decoded_chosen] + decoded_rejected = [r.split(tokenizer.pad_token)[0] for r in decoded_rejected] + + saved_data["prompt"].extend(gather_object(decoded_queries)) + saved_data["chosen"].extend(gather_object(decoded_chosen)) + saved_data["rejected"].extend(gather_object(decoded_rejected)) + saved_data["update"].extend(gather_object([update for _ in range(num_examples)])) + + torch.cuda.empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.arange(args.local_batch_size) + minibatch_idx = 0 + all_chosen_rewards = [] + all_rejected_rewards = [] + all_chosen_logprobs = [] + all_rejected_logprobs = [] + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + + ## chosen + chosen_mb_inds = chosen_indices[micro_batch_inds] + chosen_responses = responses[chosen_mb_inds] + + ## rejected + rejected_mb_inds = rejected_indices[micro_batch_inds] + rejected_responses = responses[rejected_mb_inds] + + 
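+ # run the chosen and rejected sequences through the policy in a single concatenated forward pass; the logits are split back into the two halves below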
concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) + concat_query_responses = query_responses[concat_mb_inds] + + concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id) + num_examples = chosen_mb_inds.shape[0] + chosen_logits = concat_output.logits[:num_examples] + rejected_logits = concat_output.logits[num_examples:] + + # chosen + chosen_logits = chosen_logits[:, context_length - 1 : -1] + chosen_logits /= args.temperature + 1e-7 + chosen_all_logprobs = F.log_softmax(chosen_logits, dim=-1) + chosen_logprobs = torch.gather( + chosen_all_logprobs, 2, chosen_responses.unsqueeze(-1) + ).squeeze(-1) + chosen_logprobs = torch.masked_fill( + chosen_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB + ) + chosen_ref_logprobs = ref_logprobs[chosen_mb_inds] + chosen_logprobs_sum = (chosen_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + + # rejected + rejected_logits = rejected_logits[:, context_length - 1 : -1] + rejected_logits /= args.temperature + 1e-7 + rejected_all_logprobs = F.log_softmax(rejected_logits, dim=-1) + rejected_logprobs = torch.gather( + rejected_all_logprobs, 2, rejected_responses.unsqueeze(-1) + ).squeeze(-1) + rejected_logprobs = torch.masked_fill( + rejected_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB + ) + rejected_ref_logprobs = ref_logprobs[rejected_mb_inds] + rejected_logprobs_sum = (rejected_logprobs * ~padding_mask[rejected_mb_inds]).sum(1) + rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( + 1 + ) + + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + logits = pi_logratios - ref_logratios + + if self.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) + elif self.loss_type == "ipo": + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum).detach() + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum).detach() + + loss = losses.mean() + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() + all_chosen_rewards.append(chosen_rewards) + all_chosen_logprobs.append(chosen_logprobs_sum) + all_rejected_rewards.append(rejected_rewards) + all_rejected_logprobs.append(rejected_logprobs_sum) + # entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + # ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + self.state.global_step += 1 + # del everything and empty cache + # fmt: off + del ( + loss, logits, + concat_output, concat_query_responses, + chosen_logits, rejected_logits, + chosen_logprobs, rejected_logprobs, + chosen_responses, rejected_responses, + ) + # fmt: on + torch.cuda.empty_cache() + + all_chosen_rewards = torch.cat(all_chosen_rewards, 0) + all_rejected_rewards = torch.cat(all_rejected_rewards, 0) + all_chosen_logprobs = torch.cat(all_chosen_logprobs, 0) + all_rejected_logprobs = torch.cat(all_rejected_logprobs, 0) + + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.mean() + eps = 
int(self.state.episode / (time.time() - start_time)) + # policy_chosen_logps = logprobs[chosen_indices] + # policy_rejected_logps = logprobs[rejected_indices] + + chosen_rewards = self.accelerator.gather(all_chosen_rewards) + chosen_logprobs = self.accelerator.gather(all_chosen_logprobs) + rejected_rewards = self.accelerator.gather(all_rejected_rewards) + rejected_logprobs = self.accelerator.gather(all_rejected_logprobs) + + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() + metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() + metrics["objective/scores_margin"] = self.accelerator.gather(scores_margin.mean()).mean().item() + metrics["rewards/chosen"] = chosen_rewards.mean().item() + metrics["rewards/rejected"] = rejected_rewards.mean().item() + metrics["rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item() + metrics["rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item() + metrics["logps/rejected"] = rejected_logprobs.mean().item() + metrics["logps/chosen"] = chosen_logprobs.mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() + # metrics["logits/rejected"] = policy_rejected_logits.detach().mean().cpu() + # metrics["logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.log(metrics) + del ( + kl, + mean_kl, + mean_entropy, + scores, + scores_margin, + all_chosen_rewards, + all_chosen_logprobs, + all_rejected_rewards, + all_rejected_logprobs, + ) + + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + torch.cuda.empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + if self.args.save_generations: + if accelerator.is_local_main_process: + dataset = Dataset.from_dict(saved_data) + dataset.save_to_disk(os.path.join(self.args.output_dir, "online_dataset")) + + def generate_completions(self, sampling: bool = False): + args = self.args + tokenizer = self.tokenizer + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + query_response, _ = generate( + unwrapped_model, + 
query, + tokenizer.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) + table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) + table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + if self.accelerator.process_index == 0: + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) From d88ee55dc2797398ff684c71671286f2973b719b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 5 Jul 2024 14:37:37 +0000 Subject: [PATCH 02/92] push changes --- examples/scripts/online_dpo.py | 49 +- .../trainer/online_dpo_config.py | 0 trl/trainer/online_dpo_trainer.py | 1 + trl/trainer/online_trainer.py | 469 ++++++++++++++++++ trl/trainer/ppov2_config.py | 53 +- trl/trainer/rloo_config.py | 57 +-- trl/trainer/utils.py | 38 +- 7 files changed, 553 insertions(+), 114 deletions(-) rename online_dpo_config.py => trl/trainer/online_dpo_config.py (100%) create mode 100644 trl/trainer/online_trainer.py diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index b0e73e2703d..57027b47331 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -10,8 +10,40 @@ from trl import ModelConfig from trl.commands.cli_utils import TrlParser -from trl.trainer.online_dpo_config import OnlineDPOConfig -from trl.trainer.online_dpo_trainer import OnlineDPOTrainer +from trl.trainer.online_dpo_trainer import OnlineDPOConfig, OnlineDPOTrainer +from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE + + +""" +python examples/scripts/online_dpo.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --learning_rate 3e-6 \ + --output_dir models/minimal/online_dpo \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 64 \ + --total_episodes 30000 \ + --model_name_or_path EleutherAI/pythia-14m \ + --sft_model_path EleutherAI/pythia-14m \ + --reward_model_path EleutherAI/pythia-14m \ + --non_eos_penalty \ + --stop_token eos \ + --response_length 53 \ + --sanity_check +python examples/scripts/online_dpo.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --learning_rate 3e-6 \ + --output_dir models/minimal/online_dpo \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 64 \ + --total_episodes 30000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --non_eos_penalty \ + --stop_token eos \ + --response_length 53 \ + --sanity_check +""" @dataclass @@ -54,7 +86,9 @@ def tokenize(element): padding_side="left", trust_remote_code=True, ) - + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE 
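+ # the reward model is a sequence classifier with a single scalar head (num_labels=1); the policy and the reference policy are both initialized from the same SFT checkpoint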
reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1) ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) @@ -65,14 +99,6 @@ def tokenize(element): if config.sanity_check: for key in raw_datasets: raw_datasets[key] = raw_datasets[key].select(range(1024)) - config.push_to_hub = False - config.report_to = "" - config.save_strategy = "no" - config.num_sample_generations = 0 - config.total_episodes = 32 - config.per_device_train_batch_size = 8 - config.gradient_accumulation_steps = 1 - train_dataset = raw_datasets[args.dataset_train_split] train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field) @@ -95,7 +121,6 @@ def tokenize(element): eval_dataset=eval_dataset, ) trainer.train() - if not config.sanity_check: trainer.save_model(config.output_dir) if config.push_to_hub: diff --git a/online_dpo_config.py b/trl/trainer/online_dpo_config.py similarity index 100% rename from online_dpo_config.py rename to trl/trainer/online_dpo_config.py diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index da31f0eedf7..0bc4c5f1744 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -412,6 +412,7 @@ def repeat_generator(): micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] ## chosen + breakpoint() chosen_mb_inds = chosen_indices[micro_batch_inds] chosen_responses = responses[chosen_mb_inds] diff --git a/trl/trainer/online_trainer.py b/trl/trainer/online_trainer.py new file mode 100644 index 00000000000..4ed603ccc9a --- /dev/null +++ b/trl/trainer/online_trainer.py @@ -0,0 +1,469 @@ +import gc +import os +import time +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + DataCollatorWithPadding, + GenerationConfig, + PreTrainedTokenizer, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer_callback import CallbackHandler, DefaultFlowCallback + +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + generate, + get_reward, + prepare_deepspeed, + print_rich_table, + truncate_response, +) +from .rloo_config import RLOOConfig + + +INVALID_LOGPROB = 1.0 + + +class RLOOTrainer(Trainer): + def __init__( + self, + config: RLOOConfig, + tokenizer: PreTrainedTokenizer, + policy: nn.Module, + ref_policy: nn.Module, + reward_model: nn.Module, + train_dataset: Dataset, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + # less commonly used + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + # compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + # model_init: Optional[Callable[[torch.nn.Module], None]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + ) -> None: + self.args = config + args = config + self.tokenizer = tokenizer + self.policy = policy + + 
self.policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.ref_policy = ref_policy + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + self.callbacks = callbacks + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches + ) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_updates = args.total_episodes // args.batch_size + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_updates // args.num_sample_generations) + self.local_dataloader_batch_size = exact_div( + args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k" + ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times + + ######### + # setup model, optimizer, and others + ######### + for module in [policy, ref_policy, reward_model]: + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + self.model = policy + self.create_optimizer_and_scheduler(num_training_steps=args.num_updates) + + ######### + ### trainer specifics + ######### + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + DEFAULT_CALLBACKS = [DefaultFlowCallback] + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + if self.callbacks is None: + self.callbacks = default_callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + ) + self.control = TrainerControl() + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, 
"fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + self.backup_model = None + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=DataCollatorWithPadding(tokenizer), + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=DataCollatorWithPadding(self.tokenizer), + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.bf16, args.fp16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, args.bf16, args.fp16 + ) + self.deepspeed = self.model + else: + self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + tokenizer = self.tokenizer + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + min_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + global_step = 0 + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + for update in range(1, args.num_updates + 1): + global_step += 1 * args.batch_size + self.lr_scheduler.step() + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + queries = queries.repeat(args.rloo_k, 1) + context_length = queries.shape[1] + query_responses = [] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + for i in 
range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response, logits = generate( + unwrapped_model, + query, + tokenizer.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + + # use the logits during generation directly, instead of using the following + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + + query_responses.append(query_response) + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + query_responses = torch.cat(query_responses, 0) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + del (logprob, ref_logprob, score) + torch.cuda.empty_cache() + gc.collect() + + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1) + if args.non_eos_penalty: + scores = torch.where(contain_eos_token, scores, torch.full_like(scores, args.penalty_reward_value)) + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + + # 4. 
compute rewards + kl = logprobs - ref_logprobs + non_score_reward = (-args.kl_coef * kl).sum(1) + rlhf_reward = scores + non_score_reward + + # vectorized RLOO advantages implementation + rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1) + baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1) + advantages = rlhf_reward - baseline + advantages = advantages.flatten() + torch.cuda.empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + output = forward(model, mb_query_responses, tokenizer.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + new_ratio = (new_logprobs - mb_logprobs).exp() + new_logprobs = new_logprobs.sum(1) + mb_logprobs = mb_logprobs.sum(1) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = pg_loss_max.mean() + loss = pg_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + output, logits, new_all_logprobs, new_logprobs, + logprobs_diff, ratio, pg_losses, pg_losses2, + pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, + mb_advantage, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + torch.cuda.empty_cache() + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.mean() + eps = int(global_step / (time.time() - start_time)) + metrics = {} + 
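+ # "eps" below is episodes per second (training throughput), not an epsilon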
metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() + metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = global_step + self.state.epoch = global_step / self.train_dataset_len # used by self.log + self.log(metrics) + del kl, mean_kl, mean_entropy, scores + torch.cuda.empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + + def generate_completions(self, sampling: bool = False): + args = self.args + tokenizer = self.tokenizer + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + query_response, _ = generate( + unwrapped_model, + query, + tokenizer.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) + table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) + table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + if self.accelerator.process_index == 0: + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) diff --git a/trl/trainer/ppov2_config.py b/trl/trainer/ppov2_config.py index 3cb33d0e5c6..dfee6eb2f2f 100644 --- a/trl/trainer/ppov2_config.py +++ b/trl/trainer/ppov2_config.py @@ -1,70 +1,31 @@ -import os from dataclasses import dataclass -from typing import Literal, Optional from 
transformers import ( TrainingArguments, ) -from ..trainer.utils import ( - OnpolicyRuntimeConfig, -) +from ..trainer.utils import OnpolicyRuntimeConfig @dataclass class PPOv2Config(OnpolicyRuntimeConfig, TrainingArguments): - # common config - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - run_name: Optional[str] = None - """a unique name of this run""" - sanity_check: bool = False - """wether to run in debug mode""" - - # batch size related config - num_mini_batches: int = 1 - """Number of minibatches to split a batch into""" - total_episodes: Optional[int] = None - """The total number of episodes in the dataset""" - local_rollout_forward_batch_size: int = 64 - """per rank no grad forward pass in the rollout phase""" - num_sample_generations: int = 10 - """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" - - # other config - base_model: str = "EleutherAI/pythia-160m" - """the name of the pretrained model to use""" - response_length: int = 53 - """the length of the response""" - stop_token: Optional[Literal["eos"]] = None - """the stop token""" - stop_token_id: Optional[int] = None - """the truncation token id""" - temperature: float = 0.7 - """the sampling temperature""" - penalty_reward_value: int = -1 - """the reward value for responses that do not contain `stop_token_id`""" - non_eos_penalty: bool = False - """whether to penalize responses that do not contain `stop_token_id`""" reward_model_path: str = "EleutherAI/pythia-160m" """the path to the reward model""" - sft_model_path: str = "EleutherAI/pythia-160m" - """the path to the sft model""" # ppo config num_ppo_epochs: int = 4 """the number of epochs to train""" - vf_coef: float = 0.1 - """the value function coefficient""" + whiten_rewards: bool = False + """whether to whiten the rewards""" + kl_coef: float = 0.05 + """the KL coefficient""" cliprange: float = 0.2 """the clip range""" + vf_coef: float = 0.1 + """the value function coefficient""" cliprange_value: float = 0.2 """the clip range for the value function""" gamma: float = 1 """the discount factor""" lam: float = 0.95 """the lambda value for GAE""" - whiten_rewards: bool = False - """whether to whiten the rewards""" - kl_coef: float = 0.05 - """the KL coefficient""" diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 4c3c303e832..4b1d179533a 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -1,6 +1,4 @@ -import os from dataclasses import dataclass -from typing import Literal, Optional from transformers import ( TrainingArguments, @@ -9,68 +7,17 @@ from ..trainer.utils import OnpolicyRuntimeConfig -INVALID_LOGPROB = 1.0 - - @dataclass class RLOOConfig(OnpolicyRuntimeConfig, TrainingArguments): - # common config - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" - run_name: Optional[str] = None - """a unique name of this run""" - sanity_check: bool = False - """wether to run in debug mode""" - - # batch size related config - num_mini_batches: int = 1 - """Number of minibatches to split a batch into""" - total_episodes: Optional[int] = None - """The total number of episodes in the dataset""" - local_rollout_forward_batch_size: int = 64 - """per rank no grad forward pass in the rollout phase""" - num_sample_generations: int = 10 - """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" - - # other config - base_model: str = 
"EleutherAI/pythia-160m" - """the name of the pretrained model to use""" - response_length: int = 53 - """the length of the response""" - stop_token: Optional[Literal["eos"]] = None - """the stop token""" - stop_token_id: Optional[int] = None - """the stop token id""" - temperature: float = 0.7 - """the sampling temperature""" - penalty_reward_value: int = -1 - """the reward value for responses that do not contain `stop_token_id`""" - non_eos_penalty: bool = False - """whether to penalize responses that do not contain `stop_token_id`""" - reward_model_path: str = "EleutherAI/pythia-160m" - """the path to the reward model""" - sft_model_path: str = "EleutherAI/pythia-160m" - """the path to the sft model""" - # ppo config - num_mini_batches: int = 1 - """the number of minibatches to split a batch into""" num_ppo_epochs: int = 4 """the number of epochs to train""" - vf_coef: float = 0.1 - """the value function coefficient""" - cliprange: float = 0.2 - """the clip range""" - cliprange_value: float = 0.2 - """the clip range for the value function""" - gamma: float = 1 - """the discount factor""" - lam: float = 0.95 - """the lambda value for GAE""" whiten_rewards: bool = False """whether to whiten the rewards""" kl_coef: float = 0.05 """the KL coefficient""" + cliprange: float = 0.2 + """the clip range""" # rloo config rloo_k: int = 2 diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index aa44c21ce27..34f3e696231 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import random import warnings from collections import deque from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd @@ -963,6 +964,41 @@ def print_rich_table(df: pd.DataFrame) -> Table: @dataclass class OnpolicyRuntimeConfig: + # common config + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + run_name: Optional[str] = None + """a unique name of this run""" + sanity_check: bool = False + """wether to run in debug mode""" + + + # batch size related config + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + num_sample_generations: int = 10 + """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" + + # other config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: int = -1 + """the reward value for responses that do not contain `stop_token_id`""" + non_eos_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + sft_model_path: str = "EleutherAI/pythia-160m" + """the path to the sft model""" + # various batch sizes world_size: Optional[int] = None """The number of processes (GPUs) to use""" From 1a45ec4790f9ee8bf2fef8aaa0aa7949d5b4a060 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: 
Mon, 8 Jul 2024 22:01:11 -0400 Subject: [PATCH 03/92] refactor --- trl/trainer/ppov2_config.py | 8 ++------ trl/trainer/ppov2_trainer.py | 39 ++++++++++++++++++++++-------------- trl/trainer/rloo_config.py | 8 ++------ trl/trainer/rloo_trainer.py | 24 +++++++++++++--------- trl/trainer/utils.py | 4 ++-- 5 files changed, 45 insertions(+), 38 deletions(-) diff --git a/trl/trainer/ppov2_config.py b/trl/trainer/ppov2_config.py index dfee6eb2f2f..4411a6c9de3 100644 --- a/trl/trainer/ppov2_config.py +++ b/trl/trainer/ppov2_config.py @@ -1,14 +1,10 @@ from dataclasses import dataclass -from transformers import ( - TrainingArguments, -) - -from ..trainer.utils import OnpolicyRuntimeConfig +from ..trainer.utils import OnPolicyConfig @dataclass -class PPOv2Config(OnpolicyRuntimeConfig, TrainingArguments): +class PPOv2Config(OnPolicyConfig): reward_model_path: str = "EleutherAI/pythia-160m" """the path to the reward model""" diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 23527923353..39c851577b7 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -23,7 +23,8 @@ TrainerState, ) from transformers.integrations import get_reporting_integration_callbacks -from transformers.trainer_callback import CallbackHandler, DefaultFlowCallback +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, PrinterCallback from ..core import masked_mean, masked_whiten from ..models.utils import unwrap_model_for_generation @@ -129,6 +130,7 @@ def __init__( self.local_seed = args.seed + accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max(1, args.num_updates // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size ######### # setup model, optimizer, and others @@ -147,14 +149,15 @@ def __init__( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), ) - DEFAULT_CALLBACKS = [DefaultFlowCallback] default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - if self.callbacks is None: - self.callbacks = default_callbacks + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) self.control = TrainerControl() + self.current_flos = 0 + self.hp_search_backend = None self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None # Create distant repo and output directory if needed @@ -170,7 +173,7 @@ def __init__( ######### self.dataloader = DataLoader( self.train_dataset, - batch_size=args.local_batch_size, + batch_size=self.local_dataloader_batch_size, shuffle=True, collate_fn=DataCollatorWithPadding(tokenizer), drop_last=True, # needed; otherwise the last batch will be of ragged shape @@ -191,10 +194,10 @@ def __init__( if self.is_deepspeed_enabled: self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.bf16, args.fp16 + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.ref_policy = prepare_deepspeed( - self.ref_policy, args.per_device_train_batch_size, args.bf16, 
args.fp16 + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 ) else: self.ref_policy = self.ref_policy.to(self.accelerator.device) @@ -280,9 +283,9 @@ def repeat_generator(): postprocessed_responses = [] logprobs = [] ref_logprobs = [] - values = [] scores = [] sequence_lengths = [] + values = [] with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): query = queries[i : i + args.local_rollout_forward_batch_size] @@ -294,7 +297,7 @@ def repeat_generator(): ) response = query_response[:, context_length:] - # use the logits during generation directly, instead of using the following + # use the logits during generation directly (which are already adjusted by temperature) all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) del logits, all_logprob @@ -332,17 +335,17 @@ def repeat_generator(): postprocessed_responses.append(postprocessed_response) logprobs.append(logprob) ref_logprobs.append(ref_logprob) - values.append(value) sequence_lengths.append(sequence_length) scores.append(score) + values.append(value) query_responses = torch.cat(query_responses, 0) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) ref_logprobs = torch.cat(ref_logprobs, 0) - values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) + values = torch.cat(values, 0) del (logprob, ref_logprob, full_value, value, score, unwrapped_model) torch.cuda.empty_cache() gc.collect() @@ -356,12 +359,12 @@ def repeat_generator(): # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - sequence_lengths_p1 = sequence_lengths + 1 response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) values = torch.masked_fill(values, padding_mask_p1, 0) # 4. 
compute rewards @@ -404,12 +407,12 @@ def repeat_generator(): with accelerator.accumulate(model): micro_batch_end = micro_batch_start + args.per_device_train_batch_size micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_return = returns[micro_batch_inds] mb_advantage = advantages[micro_batch_inds] - mb_values = values[micro_batch_inds] mb_responses = responses[micro_batch_inds] mb_query_responses = query_responses[micro_batch_inds] mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] output, vpred_temp = forward(model, mb_query_responses, tokenizer.pad_token_id) logits = output.logits[:, context_length - 1 : -1] @@ -527,6 +530,12 @@ def repeat_generator(): ) torch.cuda.empty_cache() + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + def generate_completions(self, sampling: bool = False): args = self.args tokenizer = self.tokenizer diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 4b1d179533a..a892d3d0d9c 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -1,14 +1,10 @@ from dataclasses import dataclass -from transformers import ( - TrainingArguments, -) - -from ..trainer.utils import OnpolicyRuntimeConfig +from ..trainer.utils import OnPolicyConfig @dataclass -class RLOOConfig(OnpolicyRuntimeConfig, TrainingArguments): +class RLOOConfig(OnPolicyConfig): # ppo config num_ppo_epochs: int = 4 """the number of epochs to train""" diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4ed603ccc9a..904f122d30c 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -23,7 +23,8 @@ TrainerState, ) from transformers.integrations import get_reporting_integration_callbacks -from transformers.trainer_callback import CallbackHandler, DefaultFlowCallback +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, PrinterCallback from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( @@ -56,8 +57,6 @@ def __init__( eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, # less commonly used optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - # compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - # model_init: Optional[Callable[[torch.nn.Module], None]] = None, callbacks: Optional[List[TrainerCallback]] = None, ) -> None: self.args = config @@ -132,14 +131,15 @@ def __init__( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), ) - DEFAULT_CALLBACKS = [DefaultFlowCallback] default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - if self.callbacks is None: - self.callbacks = default_callbacks + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) self.control = TrainerControl() + self.current_flos = 0 + self.hp_search_backend = None self.is_deepspeed_enabled = 
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None # Create distant repo and output directory if needed @@ -176,10 +176,10 @@ def __init__( if self.is_deepspeed_enabled: self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.bf16, args.fp16 + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.ref_policy = prepare_deepspeed( - self.ref_policy, args.per_device_train_batch_size, args.bf16, args.fp16 + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.deepspeed = self.model else: @@ -255,7 +255,7 @@ def repeat_generator(): ) response = query_response[:, context_length:] - # use the logits during generation directly, instead of using the following + # use the logits during generation directly (which are already adjusted by temperature) all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) del logits, all_logprob @@ -421,6 +421,12 @@ def repeat_generator(): if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: self.generate_completions(sampling=True) + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + def generate_completions(self, sampling: bool = False): args = self.args tokenizer = self.tokenizer diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 34f3e696231..01d7b832054 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -36,6 +36,7 @@ DataCollatorForLanguageModeling, PreTrainedModel, PreTrainedTokenizerBase, + TrainingArguments, ) from transformers.trainer import TrainerCallback from transformers.trainer_utils import has_length @@ -963,7 +964,7 @@ def print_rich_table(df: pd.DataFrame) -> Table: @dataclass -class OnpolicyRuntimeConfig: +class OnPolicyConfig(TrainingArguments): # common config exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" @@ -972,7 +973,6 @@ class OnpolicyRuntimeConfig: sanity_check: bool = False """wether to run in debug mode""" - # batch size related config num_mini_batches: int = 1 """Number of minibatches to split a batch into""" From 5a3bd0589ddd4e43a049f2e17a49f0bc35196d32 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 8 Jul 2024 23:36:36 -0400 Subject: [PATCH 04/92] use `batch_generation` method --- trl/trainer/ppov2_trainer.py | 147 ++++++++++++++++++----------------- trl/trainer/rloo_config.py | 3 + trl/trainer/rloo_trainer.py | 134 +++++++++++++++---------------- trl/trainer/utils.py | 23 ++++++ 4 files changed, 170 insertions(+), 137 deletions(-) diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 39c851577b7..d81dbe0eb41 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -29,11 +29,11 @@ from ..core import masked_mean, masked_whiten from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( + batch_generation, disable_dropout_in_model, exact_div, first_true_indices, forward, - generate, get_reward, prepare_deepspeed, print_rich_table, @@ -278,7 +278,6 @@ def repeat_generator(): with torch.no_grad(): queries = data["input_ids"].to(device) context_length = queries.shape[1] - 
query_responses = [] responses = [] postprocessed_responses = [] logprobs = [] @@ -287,58 +286,59 @@ def repeat_generator(): sequence_lengths = [] values = [] with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response, logits = generate( - unwrapped_model.policy, - query, - tokenizer.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - - # use the logits during generation directly (which are already adjusted by temperature) - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) + query_responses, logitss = batch_generation( + unwrapped_model.policy, + queries, + args.local_rollout_forward_batch_size, + tokenizer.pad_token_id, + generation_config, + ) - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - unwrapped_value_model = accelerator.unwrap_model(model).value_model - full_value, _, _ = get_reward( - unwrapped_value_model, query_response, tokenizer.pad_token_id, context_length - ) - value = full_value[:, context_length - 1 : -1].squeeze(-1) - _, score, _ = get_reward( - reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. 
truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response ) - query_responses.append(query_response) - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - values.append(value) - query_responses = torch.cat(query_responses, 0) + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + unwrapped_value_model = accelerator.unwrap_model(model).value_model + full_value, _, _ = get_reward( + unwrapped_value_model, query_response, tokenizer.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) @@ -548,32 +548,35 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - query_response, _ = generate( + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( unwrapped_model.policy, query, + query.shape[0], tokenizer.pad_token_id, generation_config, ) - response = query_response[:, context_length:] - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) - table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) - table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - - if sampling: - break + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) + table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) + + 
postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break df = pd.DataFrame(table) if self.accelerator.process_index == 0: print_rich_table(df.iloc[0 : 0 + 5]) diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index a892d3d0d9c..6dad7baef05 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -5,6 +5,9 @@ @dataclass class RLOOConfig(OnPolicyConfig): + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + # ppo config num_ppo_epochs: int = 4 """the number of epochs to train""" diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 904f122d30c..1a0f0629620 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -28,11 +28,11 @@ from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( + batch_generation, disable_dropout_in_model, exact_div, first_true_indices, forward, - generate, get_reward, prepare_deepspeed, print_rich_table, @@ -245,52 +245,53 @@ def repeat_generator(): scores = [] sequence_lengths = [] with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response, logits = generate( - unwrapped_model, - query, - tokenizer.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - - # use the logits during generation directly (which are already adjusted by temperature) - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) + query_responses, logitss = batch_generation( + unwrapped_model, + queries, + args.local_rollout_forward_batch_size, + tokenizer.pad_token_id, + generation_config, + ) + + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() - # Response Processing 2. 
run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - _, score, _ = get_reward( - reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response ) - query_responses.append(query_response) - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - query_responses = torch.cat(query_responses, 0) + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) @@ -439,32 +440,35 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - query_response, _ = generate( + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( unwrapped_model, query, + query.shape[0], tokenizer.pad_token_id, generation_config, ) - response = query_response[:, context_length:] - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) - table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) - table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - - if sampling: - break + response = query_response[:, 
context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) + table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break df = pd.DataFrame(table) if self.accelerator.process_index == 0: print_rich_table(df.iloc[0 : 0 + 5]) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 01d7b832054..15b81e6693c 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1235,3 +1235,26 @@ def generate( ) logits = torch.stack(output.scores, 1) return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits + + +@torch.no_grad() +def batch_generation( + model: torch.nn.Module, + queries: torch.Tensor, + local_rollout_forward_batch_size: int, + pad_token_id: int, + generation_config: dict, +): + query_responses = [] + logitss = [] + for i in range(0, queries.shape[0], local_rollout_forward_batch_size): + query = queries[i : i + local_rollout_forward_batch_size] + query_response, logits = generate( + model, + query, + pad_token_id, + generation_config, + ) + query_responses.append(query_response) + logitss.append(logits) + return torch.cat(query_responses, 0), torch.cat(logitss, 0) From aa67767c31423852d2309a1ece5b59012eecf528 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 8 Jul 2024 23:38:59 -0400 Subject: [PATCH 05/92] precommit --- trl/trainer/online_dpo_config.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 54ca5d1d505..6c0f4452b06 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -1,19 +1,23 @@ from dataclasses import dataclass from typing import Dict, Literal, Optional -from trl.trainer.rloo_config import RLOOConfig +from trl.trainer.utils import OnPolicyConfig @dataclass -class OnlineDPOConfig(RLOOConfig): - save_generations: bool = False +class OnlineDPOConfig(OnPolicyConfig): + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + # ppo config + num_ppo_epochs: int = 4 + """the number of epochs to train""" # DPO stuff w/o max_length which is included in RLOOConfig - beta: float = 0.1 + beta: float = 0.05 label_smoothing: float = 0 - loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo_hard", "nca_pair", "robust"] = ( - "sigmoid" - ) + loss_type: Literal[ + "sigmoid", "hinge", "ipo", "kto_pair", "bco_pair", "sppo_hard", "nca_pair", "robust" + ] = "sigmoid" label_pad_token_id: int = -100 padding_value: int = 0 truncation_mode: str = "keep_end" From 576c09813bd1abd37f2c56da7f50658f8c2c1d8a Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 8 Jul 2024 23:56:03 -0400 Subject: [PATCH 06/92] remove breakpoint() --- trl/trainer/online_dpo_trainer.py | 115 +++++++++++------------------- 1 file changed, 41 insertions(+), 74 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 
0bc4c5f1744..47b952cb73a 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -12,13 +12,14 @@ import torch.nn as nn import torch.nn.functional as F from accelerate import Accelerator -from accelerate.utils import gather_object +from accelerate.utils import broadcast, gather_object from datasets import Dataset from torch.utils.data import DataLoader from transformers import ( DataCollatorWithPadding, GenerationConfig, PreTrainedTokenizer, + Trainer, TrainerCallback, TrainerControl, TrainerState, @@ -27,9 +28,8 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, PrinterCallback -from trl.models.utils import unwrap_model_for_generation -from trl.trainer.rloo_trainer import INVALID_LOGPROB, RLOOTrainer -from trl.trainer.utils import ( +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( disable_dropout_in_model, exact_div, first_true_indices, @@ -40,16 +40,18 @@ print_rich_table, truncate_response, ) - from .online_dpo_config import OnlineDPOConfig +INVALID_LOGPROB = 1.0 + + @dataclass class OnlineTrainerState(TrainerState): episode: int = 0 -class OnlineDPOTrainer(RLOOTrainer): +class OnlineDPOTrainer(Trainer): def __init__( self, config: OnlineDPOConfig, @@ -62,12 +64,7 @@ def __init__( data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, # less commonly used - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), - # compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - # model_init: Optional[Callable[[torch.nn.Module], None]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), callbacks: Optional[List[TrainerCallback]] = None, ) -> None: self.args = config @@ -92,7 +89,7 @@ def __init__( # calculate various batch sizes ######### if args.total_episodes is None: # allow the users to define episodes in terms of epochs. 
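            # A worked example of the episode bookkeeping below, assuming the same
            # convention as the RLOO trainer (local_batch_size =
            # per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches):
            #   num_train_epochs=1, len(train_dataset)=128_000 -> total_episodes = 128_000
            #   per_device_train_batch_size=16, gradient_accumulation_steps=4,
            #   num_mini_batches=1, world_size=2 -> local_batch_size = 64, batch_size = 128
            #   num_updates = total_episodes // batch_size = 1_000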
- args.total_episodes = args.num_train_epochs * self.train_dataset_len + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) self.accelerator = accelerator args.world_size = accelerator.num_processes @@ -102,28 +99,19 @@ def __init__( args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) args.batch_size = int(args.local_batch_size * args.world_size) args.mini_batch_size = exact_div( - args.batch_size, - args.num_mini_batches, - "`batch_size` must be a multiple of `num_mini_batches`", + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" ) args.local_mini_batch_size = exact_div( - args.local_batch_size, - args.num_mini_batches, - "`local_batch_size` must be a multiple of `num_mini_batches`", + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.local_batch_size` - # `per_rank_minibatch_size` is our `args.local_mini_batch_size` args.num_updates = args.total_episodes // args.batch_size + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" self.local_seed = args.seed + accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max(1, args.num_updates // args.num_sample_generations) - self.local_dataloader_batch_size = args.local_batch_size - ### DPO stuff self.beta = config.beta self.loss_type = config.loss_type @@ -145,19 +133,13 @@ def __init__( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), ) - default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( - callbacks, - self.model, - self.tokenizer, - self.optimizer, - self.lr_scheduler, + self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) self.control = TrainerControl() - self.current_flos = 0 self.hp_search_backend = None self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None @@ -168,7 +150,6 @@ def __init__( self.init_hf_repo() if self.args.should_save: os.makedirs(self.args.output_dir, exist_ok=True) - self.backup_model = None ######### @@ -176,7 +157,7 @@ def __init__( ######### self.dataloader = DataLoader( self.train_dataset, - batch_size=self.local_dataloader_batch_size, + batch_size=args.local_batch_size, shuffle=True, collate_fn=DataCollatorWithPadding(tokenizer), drop_last=True, # needed; otherwise the last batch will be of ragged shape @@ -197,17 +178,21 @@ def __init__( if self.is_deepspeed_enabled: self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, config.fp16, config.bf16 + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 ) 
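            # Both prepare_deepspeed calls in this hunk reorder the precision flags:
            # the old code passed (bf16, fp16), while the helper's parameters are, by
            # all appearances at these call sites, (fp16, bf16) -- i.e. roughly
            #   def prepare_deepspeed(model, per_device_train_batch_size, fp16=False, bf16=False): ...
            # so the earlier argument order silently swapped fp16 and bf16.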
self.ref_policy = prepare_deepspeed( - self.ref_policy, args.per_device_train_batch_size, config.fp16, config.bf16 + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.deepspeed = self.model else: self.ref_policy = self.ref_policy.to(self.accelerator.device) self.reward_model = self.reward_model.to(self.accelerator.device) - self.ref_model = self.ref_policy + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader def train(self): args = self.args @@ -298,7 +283,8 @@ def repeat_generator(): generation_config, ) response = query_response[:, context_length:] - logits /= args.temperature + 1e-7 + + # use the logits during generation directly (which are already adjusted by temperature) all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) del logits, all_logprob @@ -335,9 +321,9 @@ def repeat_generator(): scores.append(score) query_responses = torch.cat(query_responses, 0) responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) ref_logprobs = torch.cat(ref_logprobs, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) del (logprob, ref_logprob, score) @@ -358,11 +344,12 @@ def repeat_generator(): logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + # 4. compute rewards kl = logprobs - ref_logprobs - non_score_reward = (-args.kl_coef * kl).sum(1) + non_score_reward = (-args.beta * kl).sum(1) rlhf_reward = scores + non_score_reward - # num_examples should be same as args.local_batch_size + # num_examples should be same as args.local_batch_size divided by 2 num_examples = scores.size(0) // 2 first_half = scores[:num_examples] second_half = scores[num_examples:] @@ -377,26 +364,11 @@ def repeat_generator(): ) scores_margin = scores[chosen_indices] - scores[rejected_indices] - - if self.args.save_generations: - decoded_queries = tokenizer.batch_decode(queries[:num_examples], skip_special_tokens=True) - decoded_chosen = tokenizer.batch_decode(postprocessed_responses[chosen_indices]) - decoded_rejected = tokenizer.batch_decode(postprocessed_responses[rejected_indices]) - - # WARNING, if pad token == eos token, this will remove the eos from the end - decoded_chosen = [r.split(tokenizer.pad_token)[0] for r in decoded_chosen] - decoded_rejected = [r.split(tokenizer.pad_token)[0] for r in decoded_rejected] - - saved_data["prompt"].extend(gather_object(decoded_queries)) - saved_data["chosen"].extend(gather_object(decoded_chosen)) - saved_data["rejected"].extend(gather_object(decoded_rejected)) - saved_data["update"].extend(gather_object([update for _ in range(num_examples)])) - torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.arange(args.local_batch_size) + b_inds = np.random.permutation(args.local_batch_size) minibatch_idx = 0 all_chosen_rewards = [] all_rejected_rewards = [] @@ -412,7 +384,6 @@ def repeat_generator(): micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] ## chosen - breakpoint() chosen_mb_inds = chosen_indices[micro_batch_inds] chosen_responses = responses[chosen_mb_inds] @@ -477,14 +448,14 @@ def 
repeat_generator(): accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - - loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() - all_chosen_rewards.append(chosen_rewards) - all_chosen_logprobs.append(chosen_logprobs_sum) - all_rejected_rewards.append(rejected_rewards) - all_rejected_logprobs.append(rejected_logprobs_sum) - # entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - # ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() + with torch.no_grad(): + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() + all_chosen_rewards.append(chosen_rewards) + all_chosen_logprobs.append(chosen_logprobs_sum) + all_rejected_rewards.append(rejected_rewards) + all_rejected_logprobs.append(rejected_logprobs_sum) + # entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + # ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 self.state.global_step += 1 @@ -562,16 +533,12 @@ def repeat_generator(): if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: self.generate_completions(sampling=True) + # HF trainer specifics self.control = self.callback_handler.on_train_end(args, self.state, self.control) if self.control.should_save: self._save_checkpoint(model, trial=None, metrics=None) self.control = self.callback_handler.on_save(self.args, self.state, self.control) - if self.args.save_generations: - if accelerator.is_local_main_process: - dataset = Dataset.from_dict(saved_data) - dataset.save_to_disk(os.path.join(self.args.output_dir, "online_dataset")) - def generate_completions(self, sampling: bool = False): args = self.args tokenizer = self.tokenizer From c5a16125a1a2f47b999817177a7e575e532e4ab9 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 9 Jul 2024 13:36:33 -0400 Subject: [PATCH 07/92] quick refactor --- trl/trainer/online_dpo_config.py | 3 + trl/trainer/online_dpo_trainer.py | 233 ++++++++++++++---------------- trl/trainer/rloo_trainer.py | 7 - 3 files changed, 114 insertions(+), 129 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 6c0f4452b06..360dad07603 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -8,11 +8,14 @@ class OnlineDPOConfig(OnPolicyConfig): reward_model_path: str = "EleutherAI/pythia-160m" """the path to the reward model""" + # ppo config num_ppo_epochs: int = 4 """the number of epochs to train""" # DPO stuff w/o max_length which is included in RLOOConfig + num_generation_per_prompt: int = 2 + """the number of generations per prompt (currently only support 2)""" beta: float = 0.05 label_smoothing: float = 0 loss_type: Literal[ diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 47b952cb73a..4e4fab737fe 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -30,11 +30,11 @@ from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( + batch_generation, disable_dropout_in_model, exact_div, first_true_indices, forward, - generate, get_reward, prepare_deepspeed, print_rich_table, @@ -60,7 +60,6 @@ def __init__( ref_policy: nn.Module, reward_model: nn.Module, train_dataset: Dataset, - loss_type: str = "sigmoid", data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: 
Optional[Union[Dataset, Dict[str, Dataset]]] = None, # less commonly used @@ -111,6 +110,11 @@ def __init__( self.local_seed = args.seed + accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max(1, args.num_updates // args.num_sample_generations) + self.local_dataloader_batch_size = exact_div( + args.local_batch_size, + args.num_generation_per_prompt, + "`local_batch_size` must be a multiple of `num_generation_per_prompt`", + ) # DPO logic: repeats the same prompt args.rloo_k times ### DPO stuff self.beta = config.beta @@ -157,7 +161,7 @@ def __init__( ######### self.dataloader = DataLoader( self.train_dataset, - batch_size=args.local_batch_size, + batch_size=self.local_dataloader_batch_size, shuffle=True, collate_fn=DataCollatorWithPadding(tokenizer), drop_last=True, # needed; otherwise the last batch will be of ragged shape @@ -226,13 +230,10 @@ def repeat_generator(): start_time = time.time() stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) loss_stats = torch.zeros(stats_shape, device=device) - - # approxkl_stats = torch.zeros(stats_shape, device=device) - # pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - # vf_loss_stats = torch.zeros(stats_shape, device=device) - # vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - # entropy_stats = torch.zeros(stats_shape, device=device) - # ratio_stats = torch.zeros(stats_shape, device=device) + chosen_rewards_stats = torch.zeros(stats_shape, device=device) + rejected_rewards_stats = torch.zeros(stats_shape, device=device) + chosen_logprobs_stats = torch.zeros(stats_shape, device=device) + rejected_logprobs_stats = torch.zeros(stats_shape, device=device) model.train() self.state.max_steps = args.num_updates * args.num_mini_batches self.state.num_train_epochs = args.total_episodes / self.train_dataset_len @@ -257,7 +258,6 @@ def repeat_generator(): self.state.save_steps = args.save_steps self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - saved_data = {"prompt": [], "chosen": [], "rejected": [], "update": []} for update in range(1, args.num_updates + 1): self.state.episode += 1 * args.batch_size @@ -265,6 +265,7 @@ def repeat_generator(): data = next(iter_dataloader) with torch.no_grad(): queries = data["input_ids"].to(device) + queries = queries.repeat(args.num_generation_per_prompt, 1) context_length = queries.shape[1] query_responses = [] responses = [] @@ -274,52 +275,52 @@ def repeat_generator(): scores = [] sequence_lengths = [] with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response, logits = generate( - unwrapped_model, - query, - tokenizer.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - - # use the logits during generation directly (which are already adjusted by temperature) - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del 
ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) + query_responses, logitss = batch_generation( + unwrapped_model, + queries, + args.local_rollout_forward_batch_size, + tokenizer.pad_token_id, + generation_config, + ) - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - _, score, _ = get_reward( - reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response ) - query_responses.append(query_response) - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - query_responses = torch.cat(query_responses, 0) + # Response Processing 2. 
run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) @@ -368,21 +369,24 @@ def repeat_generator(): # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size) + b_inds = np.random.permutation(args.local_batch_size // args.num_generation_per_prompt) minibatch_idx = 0 - all_chosen_rewards = [] - all_rejected_rewards = [] - all_chosen_logprobs = [] - all_rejected_logprobs = [] - for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.local_mini_batch_size + for mini_batch_start in range( + 0, + args.local_batch_size // args.num_generation_per_prompt, + args.local_mini_batch_size // args.num_generation_per_prompt, + ): + mini_batch_end = mini_batch_start + args.local_mini_batch_size // args.num_generation_per_prompt mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + for micro_batch_start in range( + 0, + args.local_mini_batch_size // args.num_generation_per_prompt, + args.per_device_train_batch_size, + ): with accelerator.accumulate(model): micro_batch_end = micro_batch_start + args.per_device_train_batch_size micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - ## chosen chosen_mb_inds = chosen_indices[micro_batch_inds] chosen_responses = responses[chosen_mb_inds] @@ -393,7 +397,7 @@ def repeat_generator(): concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) concat_query_responses = query_responses[concat_mb_inds] - + # breakpoint() concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id) num_examples = chosen_mb_inds.shape[0] chosen_logits = concat_output.logits[:num_examples] @@ -441,21 +445,26 @@ def repeat_generator(): else: raise NotImplementedError(f"invalid loss type {self.loss_type}") - chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum).detach() - rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum).detach() - loss = losses.mean() accelerator.backward(loss) optimizer.step() optimizer.zero_grad() with torch.no_grad(): - loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() - all_chosen_rewards.append(chosen_rewards) - all_chosen_logprobs.append(chosen_logprobs_sum) - all_rejected_rewards.append(rejected_rewards) - all_rejected_logprobs.append(rejected_logprobs_sum) - # entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - # ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) + rejected_rewards = self.beta * (rejected_logprobs_sum - 
rejected_ref_logprobs_sum) + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss + chosen_rewards_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = chosen_rewards + rejected_rewards_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = rejected_rewards + chosen_logprobs_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = chosen_logprobs_sum + rejected_logprobs_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = rejected_logprobs_sum gradient_accumulation_idx += 1 minibatch_idx += 1 self.state.global_step += 1 @@ -470,24 +479,11 @@ def repeat_generator(): ) # fmt: on torch.cuda.empty_cache() - - all_chosen_rewards = torch.cat(all_chosen_rewards, 0) - all_rejected_rewards = torch.cat(all_rejected_rewards, 0) - all_chosen_logprobs = torch.cat(all_chosen_logprobs, 0) - all_rejected_logprobs = torch.cat(all_rejected_logprobs, 0) - with torch.no_grad(): mean_kl = kl.sum(1).mean() mean_entropy = (-logprobs).sum(1).mean() mean_non_score_reward = non_score_reward.mean() eps = int(self.state.episode / (time.time() - start_time)) - # policy_chosen_logps = logprobs[chosen_indices] - # policy_rejected_logps = logprobs[rejected_indices] - - chosen_rewards = self.accelerator.gather(all_chosen_rewards) - chosen_logprobs = self.accelerator.gather(all_chosen_logprobs) - rejected_rewards = self.accelerator.gather(all_rejected_rewards) - rejected_logprobs = self.accelerator.gather(all_rejected_logprobs) metrics = {} metrics["eps"] = eps @@ -501,27 +497,17 @@ def repeat_generator(): metrics["rewards/rejected"] = rejected_rewards.mean().item() metrics["rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item() metrics["rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item() - metrics["logps/rejected"] = rejected_logprobs.mean().item() - metrics["logps/chosen"] = chosen_logprobs.mean().item() metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() - # metrics["logits/rejected"] = policy_rejected_logits.detach().mean().cpu() - # metrics["logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + metrics["train/rewards/chosen"] = self.accelerator.gather(chosen_rewards_stats).mean().item() + metrics["train/rewards/rejected"] = self.accelerator.gather(rejected_rewards_stats).mean().item() + metrics["train/logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats).mean().item() + metrics["train/logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats).mean().item() metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log self.log(metrics) - del ( - kl, - mean_kl, - mean_entropy, - scores, - scores_margin, - all_chosen_rewards, - all_chosen_logprobs, - all_rejected_rewards, - all_rejected_logprobs, - ) + del (kl, mean_kl, mean_entropy, scores, scores_margin) self.control = self.callback_handler.on_step_end(args, self.state, self.control) if self.control.should_save: @@ -551,32 +537,35 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - query_response, _ = generate( + with 
unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( unwrapped_model, query, + query.shape[0], tokenizer.pad_token_id, generation_config, ) - response = query_response[:, context_length:] - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) - table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) - table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) + table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - if sampling: - break + if sampling: + break df = pd.DataFrame(table) if self.accelerator.process_index == 0: print_rich_table(df.iloc[0 : 0 + 5]) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 1a0f0629620..e1d046a6eec 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -97,12 +97,6 @@ def __init__( args.local_mini_batch_size = exact_div( args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.local_batch_size` - # `per_rank_minibatch_size` is our `args.local_mini_batch_size` args.num_updates = args.total_episodes // args.batch_size time_tensor = torch.tensor(int(time.time()), device=accelerator.device) time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes @@ -258,7 +252,6 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) del logits, all_logprob From e264126d919924473e7a40461a5a58589730d118 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 11 Jul 2024 13:39:16 +0000 Subject: [PATCH 08/92] push the current changes --- examples/scripts/online_dpo.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 
deletions(-)

diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py
index 57027b47331..4b5c2c2032a 100644
--- a/examples/scripts/online_dpo.py
+++ b/examples/scripts/online_dpo.py
@@ -29,20 +29,24 @@
     --stop_token eos \
     --response_length 53 \
     --sanity_check
-python examples/scripts/online_dpo.py \
+accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
+    examples/scripts/online_dpo.py \
     --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
     --learning_rate 3e-6 \
     --output_dir models/minimal/online_dpo \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 64 \
-    --total_episodes 30000 \
+    --per_device_train_batch_size 16 \
+    --local_rollout_forward_batch_size 32 \
+    --gradient_accumulation_steps 4 \
+    --total_episodes 1000000 \
     --model_name_or_path EleutherAI/pythia-1b-deduped \
     --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
     --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
+    --save_strategy no \
     --non_eos_penalty \
     --stop_token eos \
+    --beta 0.1 \
     --response_length 53 \
-    --sanity_check
+    --push_to_hub
 """
@@ -51,7 +55,7 @@ class ScriptArguments:
     dataset_name: str = None
     dataset_text_field: str = "prompt"
     dataset_train_split: str = "train"
-    dataset_test_split: Optional[str] = "test"
+    dataset_test_split: Optional[str] = "validation"
     max_length: int = 512

From 9653edb449e9ee0b784ff483fe707245ef5eca00 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Mon, 15 Jul 2024 15:27:26 +0000
Subject: [PATCH 09/92] quick change

---
 docs/source/ppov2_trainer.md      | 13 ++++++++++++-
 docs/source/rloo_trainer.md       |  4 ++--
 examples/scripts/online_dpo.py    |  4 +++-
 trl/trainer/online_dpo_trainer.py |  8 ++++----
 4 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/docs/source/ppov2_trainer.md b/docs/source/ppov2_trainer.md
index f9c8aaa58d1..98c7dc37242 100644
--- a/docs/source/ppov2_trainer.md
+++ b/docs/source/ppov2_trainer.md
@@ -199,7 +199,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml
     --model_name_or_path EleutherAI/pythia-6.9b-deduped \
     --sft_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \
     --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
-    --local_rollout_forward_batch_size 2 \
+    --local_rollout_forward_batch_size 4 \
     --non_eos_penalty \
     --stop_token eos \
 ```
@@ -229,7 +229,18 @@ python -i examples/scripts/evals/generate_tldr.py \
 # response1    472
 # Name: count, dtype: int64
 ```
+```python
+import matplotlib.pyplot as plt
+ys = [34.4, 53.2, 52.8]
+xs = ["SFT policy", "RLOO policy 1B", "PPO Policy 1B"]
+
+plt.bar(xs, ys)
+plt.ylabel('Win rate against reference summaries')
+plt.xlabel('Model Name')
+plt.title('Win Rate Comparison')
+
+plt.show()
+```
 The PPO checkpoint gets a 52.8% preference rate vs the 34.4% preference rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended.
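As an aside on the numbers above: the win rates come from the `preferred` value counts printed by the eval script. A minimal sketch of that bookkeeping, assuming the script writes a CSV with a `preferred` column naming the winning response; the path and the convention that `response0` holds the model completion are assumptions for illustration, not part of this patch:

```python
# Hypothetical post-processing of the judge's output CSV.
import pandas as pd

df = pd.read_csv("examples/scripts/evals/ppo_tldr.csv")  # assumed output path
counts = df["preferred"].value_counts()
# Fraction of comparisons in which the judge preferred the model completion.
win_rate = counts.get("response0", 0) / counts.sum()
print(f"win rate against reference summaries: {win_rate:.1%}")
```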
diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 1e47b5b0196..eed8fb0f59a 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -218,8 +218,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --num_ppo_epochs 2 \ --num_mini_batches 2 \ --learning_rate 3e-6 \ - --per_device_train_batch_size 16 \ - --gradient_accumulation_steps 16 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 8 \ --total_episodes 1000000 \ --model_name_or_path EleutherAI/pythia-1b-deduped \ --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index 4b5c2c2032a..b0ca3cbce5a 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -36,9 +36,11 @@ --output_dir models/minimal/online_dpo \ --per_device_train_batch_size 16 \ --local_rollout_forward_batch_size 32 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ --gradient_accumulation_steps 4 \ --total_episodes 1000000 \ - --model_name_or_path EleutherAI/pythia-1b-deduped \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --save_strategy no \ diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 4e4fab737fe..f588a82ba99 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -455,16 +455,16 @@ def repeat_generator(): loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss chosen_rewards_stats[ ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = chosen_rewards + ] = chosen_rewards.mean() rejected_rewards_stats[ ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = rejected_rewards + ] = rejected_rewards.mean() chosen_logprobs_stats[ ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = chosen_logprobs_sum + ] = chosen_logprobs_sum.mean() rejected_logprobs_stats[ ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = rejected_logprobs_sum + ] = rejected_logprobs_sum.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 self.state.global_step += 1 From 798d1d6feec92ac2d9111e246fbd41baec6a6ff5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 15 Jul 2024 15:34:56 +0000 Subject: [PATCH 10/92] refactor --- trl/trainer/online_dpo_trainer.py | 19 ++++--------- trl/trainer/ppov2_trainer.py | 45 ++++++++++++++++++++++++------- trl/trainer/rloo_trainer.py | 45 +++++++++++++++++++++++++------ trl/trainer/utils.py | 6 +++++ 4 files changed, 84 insertions(+), 31 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index f588a82ba99..f137f9aa6bf 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -39,6 +39,7 @@ prepare_deepspeed, print_rich_table, truncate_response, + OnlineTrainerState, ) from .online_dpo_config import OnlineDPOConfig @@ -46,11 +47,6 @@ INVALID_LOGPROB = 1.0 -@dataclass -class OnlineTrainerState(TrainerState): - episode: int = 0 - - class OnlineDPOTrainer(Trainer): def __init__( self, @@ -225,8 +221,6 @@ def repeat_generator(): ) accelerator.print("===training policy===") - self.state.global_step = 0 - self.state.episode = 0 start_time = time.time() stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) loss_stats = torch.zeros(stats_shape, 
device=device) @@ -235,11 +229,12 @@ def repeat_generator(): chosen_logprobs_stats = torch.zeros(stats_shape, device=device) rejected_logprobs_stats = torch.zeros(stats_shape, device=device) model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 self.state.max_steps = args.num_updates * args.num_mini_batches self.state.num_train_epochs = args.total_episodes / self.train_dataset_len - self.state.is_local_process_zero = self.is_local_process_zero() - self.state.is_world_process_zero = self.is_world_process_zero() - # Compute absolute values for logging, eval, and save if given as ratio if args.logging_steps is not None: if args.logging_steps < 1: @@ -256,7 +251,6 @@ def repeat_generator(): self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) else: self.state.save_steps = args.save_steps - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) for update in range(1, args.num_updates + 1): @@ -267,7 +261,6 @@ def repeat_generator(): queries = data["input_ids"].to(device) queries = queries.repeat(args.num_generation_per_prompt, 1) context_length = queries.shape[1] - query_responses = [] responses = [] postprocessed_responses = [] logprobs = [] @@ -397,7 +390,6 @@ def repeat_generator(): concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) concat_query_responses = query_responses[concat_mb_inds] - # breakpoint() concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id) num_examples = chosen_mb_inds.shape[0] chosen_logits = concat_output.logits[:num_examples] @@ -484,7 +476,6 @@ def repeat_generator(): mean_entropy = (-logprobs).sum(1).mean() mean_non_score_reward = non_score_reward.mean() eps = int(self.state.episode / (time.time() - start_time)) - metrics = {} metrics["eps"] = eps metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index d81dbe0eb41..1c50bcc1320 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -1,4 +1,5 @@ import gc +import math import os import time from collections import OrderedDict, defaultdict @@ -20,7 +21,6 @@ Trainer, TrainerCallback, TrainerControl, - TrainerState, ) from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK @@ -29,6 +29,7 @@ from ..core import masked_mean, masked_whiten from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( + OnlineTrainerState, batch_generation, disable_dropout_in_model, exact_div, @@ -96,7 +97,6 @@ def __init__( self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers - self.callbacks = callbacks ######### # calculate various batch sizes @@ -145,7 +145,7 @@ def __init__( ######### ### trainer specifics ######### - self.state = TrainerState( + self.state = OnlineTrainerState( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), ) @@ -260,7 +260,6 @@ def repeat_generator(): ) accelerator.print("===training policy===") - global_step = 0 start_time = time.time() stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) @@ -271,8 +270,32 @@ def repeat_generator(): entropy_stats = torch.zeros(stats_shape, device=device) ratio_stats = torch.zeros(stats_shape, 
device=device) model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_updates * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + for update in range(1, args.num_updates + 1): - global_step += 1 * args.batch_size + self.state.episode += 1 * args.batch_size self.lr_scheduler.step() data = next(iter_dataloader) with torch.no_grad(): @@ -299,7 +322,6 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) del logits, all_logprob @@ -481,7 +503,7 @@ def repeat_generator(): mean_entropy = (-logprobs).sum(1).mean() mean_non_score_reward = non_score_reward.sum(1).mean() rlhf_reward = mean_non_score_reward + scores.mean() - eps = int(global_step / (time.time() - start_time)) + eps = int(self.state.episode / (time.time() - start_time)) metrics = {} metrics["eps"] = eps metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() @@ -499,10 +521,15 @@ def repeat_generator(): metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = global_step - self.state.epoch = global_step / self.train_dataset_len # used by self.log + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log self.log(metrics) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) torch.cuda.empty_cache() gc.collect() diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index e1d046a6eec..e9ce5b44cd9 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -20,7 +20,6 @@ Trainer, TrainerCallback, TrainerControl, - TrainerState, ) from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK @@ -28,6 +27,7 @@ from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( + OnlineTrainerState, batch_generation, disable_dropout_in_model, exact_div, @@ -76,7 +76,6 @@ def 
__init__( self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers - self.callbacks = callbacks ######### # calculate various batch sizes @@ -121,7 +120,7 @@ def __init__( ######### ### trainer specifics ######### - self.state = TrainerState( + self.state = OnlineTrainerState( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), ) @@ -191,6 +190,7 @@ def train(self): accelerator = self.accelerator optimizer = self.optimizer model = self.model + self.model_wrapped = self.model ref_policy = self.ref_policy reward_model = self.reward_model tokenizer = self.tokenizer @@ -212,7 +212,6 @@ def repeat_generator(): ) accelerator.print("===training policy===") - global_step = 0 start_time = time.time() stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) @@ -223,8 +222,32 @@ def repeat_generator(): entropy_stats = torch.zeros(stats_shape, device=device) ratio_stats = torch.zeros(stats_shape, device=device) model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_updates * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + for update in range(1, args.num_updates + 1): - global_step += 1 * args.batch_size + self.state.episode += 1 * args.batch_size self.lr_scheduler.step() data = next(iter_dataloader) with torch.no_grad(): @@ -373,6 +396,7 @@ def repeat_generator(): ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 + self.state.global_step += 1 # del everything and empty cache # fmt: off del ( @@ -387,7 +411,7 @@ def repeat_generator(): mean_kl = kl.sum(1).mean() mean_entropy = (-logprobs).sum(1).mean() mean_non_score_reward = non_score_reward.mean() - eps = int(global_step / (time.time() - start_time)) + eps = int(self.state.episode / (time.time() - start_time)) metrics = {} metrics["eps"] = eps metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() @@ -405,10 +429,15 @@ def repeat_generator(): metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = global_step - self.state.epoch = global_step / self.train_dataset_len # used by self.log + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log self.log(metrics) del kl, mean_kl, mean_entropy, scores + + self.control = 
self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) torch.cuda.empty_cache() gc.collect() diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 15b81e6693c..d29193448ec 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -37,6 +37,7 @@ PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments, + TrainerState, ) from transformers.trainer import TrainerCallback from transformers.trainer_utils import has_length @@ -963,6 +964,11 @@ def print_rich_table(df: pd.DataFrame) -> Table: # is to have the generated response to end with an EOS token, but the query itself should not end with EOS tokens. +@dataclass +class OnlineTrainerState(TrainerState): + episode: int = 0 + + @dataclass class OnPolicyConfig(TrainingArguments): # common config From 6562bc25d45321095f477ca378051335f681b057 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 15 Jul 2024 15:44:19 +0000 Subject: [PATCH 11/92] use the config name as the experiment name --- trl/trainer/online_dpo_config.py | 3 +++ trl/trainer/ppov2_config.py | 3 +++ trl/trainer/utils.py | 2 -- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 360dad07603..b2ca4fdd6ae 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from typing import Dict, Literal, Optional @@ -6,6 +7,8 @@ @dataclass class OnlineDPOConfig(OnPolicyConfig): + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" reward_model_path: str = "EleutherAI/pythia-160m" """the path to the reward model""" diff --git a/trl/trainer/ppov2_config.py b/trl/trainer/ppov2_config.py index 4411a6c9de3..05247f4fae1 100644 --- a/trl/trainer/ppov2_config.py +++ b/trl/trainer/ppov2_config.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from ..trainer.utils import OnPolicyConfig @@ -5,6 +6,8 @@ @dataclass class PPOv2Config(OnPolicyConfig): + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" reward_model_path: str = "EleutherAI/pythia-160m" """the path to the reward model""" diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index d29193448ec..c857378a4a4 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -972,8 +972,6 @@ class OnlineTrainerState(TrainerState): @dataclass class OnPolicyConfig(TrainingArguments): # common config - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" run_name: Optional[str] = None """a unique name of this run""" sanity_check: bool = False From 7a0c2739bb67f62bcd3e0c1435ab65002d800a22 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 16 Jul 2024 14:18:41 +0000 Subject: [PATCH 12/92] fix logging --- trl/trainer/online_dpo_trainer.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index f137f9aa6bf..315523696f7 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -3,7 +3,6 @@ import os import time from collections import defaultdict -from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -22,7 +21,6 @@ Trainer, TrainerCallback, 
TrainerControl, - TrainerState, ) from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK @@ -30,6 +28,7 @@ from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( + OnlineTrainerState, batch_generation, disable_dropout_in_model, exact_div, @@ -39,7 +38,6 @@ prepare_deepspeed, print_rich_table, truncate_response, - OnlineTrainerState, ) from .online_dpo_config import OnlineDPOConfig @@ -476,6 +474,8 @@ def repeat_generator(): mean_entropy = (-logprobs).sum(1).mean() mean_non_score_reward = non_score_reward.mean() eps = int(self.state.episode / (time.time() - start_time)) + g_chosen_reward = self.accelerator.gather(chosen_rewards_stats) + g_rejected_reward = self.accelerator.gather(rejected_rewards_stats) metrics = {} metrics["eps"] = eps metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() @@ -484,15 +484,13 @@ def repeat_generator(): metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() metrics["objective/scores_margin"] = self.accelerator.gather(scores_margin.mean()).mean().item() - metrics["rewards/chosen"] = chosen_rewards.mean().item() - metrics["rewards/rejected"] = rejected_rewards.mean().item() - metrics["rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item() - metrics["rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item() + metrics["rewards/chosen"] = g_chosen_reward.mean().item() + metrics["rewards/rejected"] = g_rejected_reward.mean().item() + metrics["rewards/accuracies"] = (g_chosen_reward > g_rejected_reward).float().mean().item() + metrics["rewards/margins"] = (g_chosen_reward - g_rejected_reward).mean().item() metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() - metrics["train/rewards/chosen"] = self.accelerator.gather(chosen_rewards_stats).mean().item() - metrics["train/rewards/rejected"] = self.accelerator.gather(rejected_rewards_stats).mean().item() - metrics["train/logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats).mean().item() - metrics["train/logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats).mean().item() + metrics["logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats).mean().item() + metrics["logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats).mean().item() metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode From 7e03124746bc26081393be48b9c1d9ee438825df Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 16 Jul 2024 16:29:09 +0000 Subject: [PATCH 13/92] update online DPO docs --- docs/source/online_dpo_trainer.md | 295 ++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 docs/source/online_dpo_trainer.md diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md new file mode 100644 index 00000000000..ce986fd3018 --- /dev/null +++ b/docs/source/online_dpo_trainer.md @@ -0,0 +1,295 @@ +# Online DPO Trainer + +TRL supports training LLMs with online DPO ([Guo et al., 2024](https://arxiv.org/abs/2402.04792)) with a reward model (RM). The idea of online DPO is to generate completions based on prompts and either have an RM or a LLM judge to rank the responses. 
Then the policy is updated with the ranked responses using the DPO loss.
+
+While [Guo et al. (2024)](https://arxiv.org/abs/2402.04792) used a LLM judge, in this implementation we just used a RM.
+
+
+## Get started
+
+To just run the online DPO script to make sure the trainer can run, you can run the following command to train an online DPO model with a dummy reward model.
+
+```bash
+python examples/scripts/online_dpo.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --learning_rate 3e-6 \
+    --output_dir models/minimal/online_dpo \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 64 \
+    --total_episodes 30000 \
+    --model_name_or_path EleutherAI/pythia-14m \
+    --sft_model_path EleutherAI/pythia-14m \
+    --reward_model_path EleutherAI/pythia-14m \
+    --non_eos_penalty \
+    --stop_token eos \
+    --response_length 53 \
+    --sanity_check
+```
+
+
+## Explanation of the logged metrics
+
+The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35).
+
+* `eps`: Tracks the number of episodes per second.
+* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
+* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
+* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
+* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
+* `objective/scores`: The mean scores returned by the reward model / environment.
+* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
+* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model.
+* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model) of the chosen completions.
+* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
+* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
+* `logps/chosen`: The mean log probabilities of the chosen completions.
+* `logps/rejected`: The mean log probabilities of the rejected completions.
+* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
+* `lr`: The current learning rate used by the optimizer.
+* `episode`: The current global step or episode count in the training process.
+
+
+## Cookbook
+
+* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
+* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
+* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
+* Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`.
This can help the model learn to generate more coherent completions. + + +## What is my model doing exactly? + +To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif?download=true) + + +In the logs the sampled generations look like + +``` +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ +┃ query ┃ model response ┃ score ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ +│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │ +│ │ I don't know how to get rid of │ │ +│ TITLE: How do you get someone │ those feelings. I'm │ │ +│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │ +│ │ │ │ +│ POST: Hi, │ │ │ +│ I'm 22, and I have been with my │ │ │ +│ girlfriend for 5 years now. We │ │ │ +│ recently moved together. We've │ │ │ +│ always loved each other │ │ │ +│ intensely. │ │ │ +│ │ │ │ +│ Problem, I recently started to │ │ │ +│ have feelings for an other │ │ │ +│ person (a friend). This person │ │ │ +│ has had a boyfriend for now 3 │ │ │ +│ years, and has absolutely no │ │ │ +│ ideas. Those feelings were so │ │ │ +│ strong, it was hard to hide │ │ │ +│ them. After 2 months of me │ │ │ +│ being distant and really sad, │ │ │ +│ my girlfriend forced me to say │ │ │ +│ what was bothering me. I'm not │ │ │ +│ a good liar, and now she knows. │ │ │ +│ │ │ │ +│ We decided to give us a week │ │ │ +│ alone, I went to my parents. │ │ │ +│ │ │ │ +│ Now, I'm completely lost. I │ │ │ +│ keep on thinking about this │ │ │ +│ person, and I hate that. I │ │ │ +│ would like for those feelings │ │ │ +│ to go away, to leave me alone. │ │ │ +│ But I can't. │ │ │ +│ │ │ │ +│ What do I do? It's been 3 │ │ │ +│ months now, and I'm just │ │ │ +│ desperate. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +├─────────────────────────────────┼─────────────────────────────────┼──────────┤ +│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │ +│ │ TV. I blasted Gangnam Style on │ │ +│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │ +│ with a loud TV. │ up as high as it could │ │ +│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │ +│ POST: She was in her living │ │ │ +│ room, watching TV. This was at │ │ │ +│ about 8:30 in the morning, and │ │ │ +│ she was exercising. She turned │ │ │ +│ the TV up extra loud to hear it │ │ │ +│ over her excercycle, and woke │ │ │ +│ me up. I went in there asking │ │ │ +│ for her to turn it down. She │ │ │ +│ said she didn't have to; I │ │ │ +│ explained that I always used │ │ │ +│ headphones so she didn't have │ │ │ +│ to deal with my noise and that │ │ │ +│ she should give me a little │ │ │ +│ more respect, given that I paid │ │ │ +│ rent at the time. │ │ │ +│ │ │ │ +│ She disagreed. I went back to │ │ │ +│ my room, rather pissed off at │ │ │ +│ the lack of equality. I had no │ │ │ +│ lock on my door; but I had a │ │ │ +│ dresser right next to it, so I │ │ │ +│ pulled one of the drawers out │ │ │ +│ enough so that it caused the │ │ │ +│ door to not be openable. 
Then, │                                 │          │
+│ I turned my speakers up really  │                                 │          │
+│ loud and blasted Gangnam Style  │                                 │          │
+│ on repeat, with the bass        │                                 │          │
+│ cranked up as high as it could  │                                 │          │
+│ go.                             │                                 │          │
+│                                 │                                 │          │
+│ If you hate Gangnam Style for   │                                 │          │
+│ being overplayed, you will see  │                                 │          │
+│ why I chose that particular     │                                 │          │
+│ song. I personally don't mind   │                                 │          │
+│ it. But here's the thing about  │                                 │          │
+│ my bass; it vibrates the walls, │                                 │          │
+│ making one hell of a lot of     │                                 │          │
+│ noise. Needless to say, my mom  │                                 │          │
+│ was not pleased and shut off    │                                 │          │
+│ the internet. But it was oh so  │                                 │          │
+│ worth it.                       │                                 │          │
+│                                 │                                 │          │
+│ TL;DR:                          │                                 │          │
+├─────────────────────────────────┼─────────────────────────────────┼──────────┤
+```
+
+## Implementation details
+
+Many online implementation details are borrowed from the PPOv2Trainer, which is itself based on [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://arxiv.org/pdf/2403.17031). Here are some additional implementation details:
+
+1. When we turn on the EOS trick (i.e., replacing the score of completions that do not end with an EOS token with a scalar penalty score like `-1`) via `--non_eos_penalty --stop_token eos`, it's possible that the chosen and rejected completions have the same score. In this case, we will naively select the completion with the lower index as the chosen completion.
+
+## Benchmark experiments
+
+To validate that the online DPO implementation works, we ran experiments on the 1B and 6.9B models. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://arxiv.org/pdf/2403.17031).
+
+
+```
+# 1B online DPO experiment
+accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
+    examples/scripts/online_dpo.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --learning_rate 3e-6 \
+    --output_dir models/minimal/online_dpo_tldr \
+    --per_device_train_batch_size 16 \
+    --gradient_accumulation_steps 4 \
+    --local_rollout_forward_batch_size 32 \
+    --num_ppo_epochs 1 \
+    --num_mini_batches 1 \
+    --total_episodes 1000000 \
+    --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
+    --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
+    --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
+    --save_strategy no \
+    --non_eos_penalty \
+    --stop_token eos \
+    --beta 0.1 \
+    --response_length 53 \
+    --push_to_hub
+
+# 6.9B online DPO experiment
+accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
+    examples/scripts/online_dpo.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --learning_rate 3e-6 \
+    --output_dir models/minimal/online_dpo_tldr_6.9b \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 16 \
+    --local_rollout_forward_batch_size 8 \
+    --num_ppo_epochs 1 \
+    --num_mini_batches 1 \
+    --total_episodes 1000000 \
+    --model_name_or_path EleutherAI/pythia-6.9b-deduped \
+    --sft_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \
+    --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
+    --save_strategy no \
+    --non_eos_penalty \
+    --stop_token eos \
+    --beta 0.1 \
+    --response_length 53 \
+    --push_to_hub
+```
+
+The 1B experiment can be found here:
+
+- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr)
+- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
+
+
+To evaluate, we use vLLM to load the checkpoints and GPT-4 as a judge model to evaluate the generated TL;DR against the reference TL;DR.
+```bash
+#### using GPT-4 as a judge
+python -i examples/scripts/evals/generate_tldr.py \
+    --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
+    --judge_model gpt-4-0613 \
+    --output_path examples/scripts/evals/sft_tldr.csv \
+    --n 1000
+# preferred
+# response1    790
+# response0    210
+# Name: count, dtype: int64
+python -i examples/scripts/evals/generate_tldr.py \
+    --model_name_or_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \
+    --judge_model gpt-4-0613 \
+    --output_path examples/scripts/evals/sft_tldr.csv \
+    --n 1000
+# preferred
+# response1    691
+# response0    309
+# Name: count, dtype: int64
+python -i examples/scripts/evals/generate_tldr.py \
+    --model_name_or_path vwxyzjn/online_dpo_tldr \
+    --judge_model gpt-4-0613 \
+    --output_path examples/scripts/evals/online_dpo_tldr.csv \
+    --n 1000
+# preferred
+# response0    532
+# response1    468
+# Name: count, dtype: int64
+python -i examples/scripts/evals/generate_tldr.py \
+    --model_name_or_path vwxyzjn/online_dpo_tldr_6.9b \
+    --judge_model gpt-4-0613 \
+    --output_path examples/scripts/evals/online_dpo_tldr_6.9b.csv \
+    --n 1000
+# preferred
+# response0    780
+# response1    220
+# Name: count, dtype: int64
+```
+
+We can then plot the RLHF scaling chart.
+
+```python
+import matplotlib.pyplot as plt
+data = {
+    "SFT": [[1e9, 6.9e9], [210 / 1000, 309 / 1000]],
+    "online DPO": [[1e9, 6.9e9], [532 / 1000, 780 / 1000]],
+}
+for model, (x, y) in data.items():
+    plt.scatter(x, y, label=model)
+plt.axhline(y=0.5, color="black", linestyle="-.", label="human reference summary")
+plt.title("RLHF scaling by model size")
+plt.xlabel("Model size")
+plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
+plt.xscale("log")
+plt.xlim(5e8, 1e10)
+plt.legend()
+plt.grid(True, which="both", ls="--", c="0.7")
+plt.tight_layout()
+plt.savefig("plot.png")
+```
+
+
+![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online%20DPO%20scaling.png?download=true)
+
+The online DPO checkpoint achieves an increasingly higher win rate as we scale up the model size. This is a good sign that the online DPO implementation is working as intended.
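To make the loss behind the `rewards/*` metrics above concrete, here is a minimal sketch of the pairwise objective this trainer optimizes, assuming the default sigmoid loss variant and `beta=0.1`; the tensor names are illustrative and this is not the exact TRL implementation:

```python
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards are beta-scaled log-prob ratios vs. the reference model;
    # these are what `rewards/chosen` and `rewards/rejected` report.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # -log(sigmoid(margin)): minimized when chosen completions outrank rejected ones.
    loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
    return loss, chosen_rewards.detach(), rejected_rewards.detach()

# Tiny usage example with made-up summed log-probs for a batch of two pairs.
loss, cr, rr = dpo_sigmoid_loss(
    torch.tensor([-10.0, -12.0]), torch.tensor([-14.0, -13.0]),
    torch.tensor([-11.0, -12.5]), torch.tensor([-13.0, -12.0]),
)
print(loss.item(), cr.tolist(), rr.tolist())
```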
+ From 0f8b1e31ca9b75ac5b70daff76c575e6aa40e0c7 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 17 Jul 2024 03:03:53 +0000 Subject: [PATCH 14/92] use llm as a judge --- examples/scripts/online_dpo_llmjudge.py | 276 ++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 examples/scripts/online_dpo_llmjudge.py diff --git a/examples/scripts/online_dpo_llmjudge.py b/examples/scripts/online_dpo_llmjudge.py new file mode 100644 index 00000000000..3a601f1084c --- /dev/null +++ b/examples/scripts/online_dpo_llmjudge.py @@ -0,0 +1,276 @@ +import asyncio +import random +import time +from dataclasses import dataclass +from typing import Optional + +import pandas as pd +from datasets import load_dataset +from openai import AsyncOpenAI +from tqdm.asyncio import tqdm_asyncio +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, +) + +from trl import ModelConfig +from trl.commands.cli_utils import TrlParser +from trl.trainer.online_dpo_trainer import OnlineDPOConfig, OnlineDPOTrainer +from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE + + +""" +python examples/scripts/online_dpo_llmjudge.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --learning_rate 3e-6 \ + --output_dir models/minimal/online_dpo_llmjudge \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 64 \ + --total_episodes 30000 \ + --model_name_or_path EleutherAI/pythia-14m \ + --sft_model_path EleutherAI/pythia-14m \ + --reward_model_path EleutherAI/pythia-14m \ + --non_eos_penalty \ + --stop_token eos \ + --response_length 53 \ + --sanity_check +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/online_dpo_llmjudge.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --learning_rate 3e-6 \ + --output_dir models/minimal/online_dpo_llmjudge \ + --per_device_train_batch_size 16 \ + --local_rollout_forward_batch_size 32 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --gradient_accumulation_steps 4 \ + --total_episodes 1000000 \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --save_strategy no \ + --non_eos_penalty \ + --stop_token eos \ + --beta 0.1 \ + --response_length 53 \ + --push_to_hub +python \ + examples/scripts/online_dpo_llmjudge.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --learning_rate 3e-6 \ + --output_dir models/minimal/online_dpo_llmjudge_tldr \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 16 \ + --local_rollout_forward_batch_size 32 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --total_episodes 1000000 \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --save_strategy no \ + --non_eos_penalty \ + --stop_token eos \ + --beta 0.1 \ + --response_length 53 \ + --push_to_hub +""" + + +@dataclass +class ScriptArguments: + dataset_name: str = None + dataset_text_field: str = "prompt" + dataset_train_split: str = "train" + dataset_test_split: Optional[str] = "validation" + max_length: int = 512 + + +def prepare_dataset(dataset, tokenizer, dataset_text_field): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + 
outputs = tokenizer(
+            element[dataset_text_field],
+            padding=False,
+        )
+        return {"input_ids": outputs["input_ids"]}
+
+    return dataset.map(
+        tokenize,
+        remove_columns=dataset.column_names,
+        batched=True,
+        num_proc=4,  # multiprocessing.cpu_count(),
+        load_from_cache_file=False,
+    )
+
+
+TEMPLATE = r"""
+Which of the following summaries does a better job of summarizing the most important points in the given forum post, without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence.
+
+### Post:
+{{post}}
+
+### Summary A:
+{{response0}}
+
+### Summary B:
+{{response1}}
+
+### Instructions:
+FIRST provide a one-sentence comparison of the two summaries, explaining which \
+you prefer and why. SECOND, on a new line, state only "A" or "B" to indicate your choice. Your response should use the format:
+Comparison: <one-sentence comparison and explanation>
+Preferred: <"A" or "B">
+"""
+
+# 1. extract the common text as the query (this way we do not require the user to provide the query)
+# 2. support different kinds of judges (let us say mainly LLM judges at the moment)
+
+
+@dataclass
+class LLMJudgeConfig:
+    n: int = 64
+    model: str = "gpt-3.5-turbo-0125"
+    max_parallel_requests: Optional[int] = None
+    llm_judge_template: str = ""
+
+    def __post_init__(self):
+        if "gpt-3.5" in self.model:
+            # gpt-3.5 generates so fast that it will exceeds the
+            # token limit per minute
+            self.max_parallel_requests = 11
+        elif "gpt-4" in self.model:
+            self.max_parallel_requests = 13
+
+
+class LLMJudge:
+    def __init__(self, ljc: LLMJudgeConfig):
+        self.ljc = ljc
+        self.async_client = AsyncOpenAI()
+
+    async def process_text(self, post: str, response0: str, response1: str, i: int, limiter=None):
+        text = self.ljc.llm_judge_template.replace("{{post}}", post)
+        text = text.replace("{{response0}}", response0)
+        text = text.replace("{{response1}}", response1)
+
+        async with limiter:
+            response = None
+            while response is None:
+                try:
+                    response = await self.async_client.chat.completions.create(
+                        model=self.ljc.model,
+                        messages=[
+                            {"role": "system", "content": "You are a helpful assistant."},
+                            {"role": "user", "content": text},
+                        ],
+                    )
+                    r = response.choices[0].message.content
+                except Exception as e:
+                    print(f"error in {i}: {e}")
+                    time.sleep(30)
+                    continue
+
+            try:
+                comparison = r.split("Comparison:")[1].split("Preferred:")[0].strip()
+                preferred = r.split("Preferred:")[1].strip()
+                return comparison, preferred, i, text + r
+            except Exception as e:
+                print(f"error in {i} {e}")
+                return "", random.choice(["A", "B"]), i, text + r
+
+    def judge(self, df: pd.DataFrame):
+        async def main(ljc: LLMJudgeConfig, df: pd.DataFrame):
+            """`df` should have columns: `prompt`, `response0`, `response1`"""
+            limiter = asyncio.Semaphore(ljc.max_parallel_requests)
+            tasks = []
+            df["explanation"] = [None for _ in range(len(df))]
+            df["preferred"] = [None for _ in range(len(df))]
+            df["shuffled_index"] = [None for _ in range(len(df))]
+            df["entire_conversation"] = [None for _ in range(len(df))]
+            r = range(min(ljc.n, len(df)))
+            if ljc.n == -1:
+                r = range(len(df))
+            for i in r:
+                post = df["prompt"].iloc[i].strip()
+                # shuffle the index to avoid GPT-4's preference bias in the content's order
+                shuffled_index = random.randint(0, 1)
+                df.at[i, "shuffled_index"] = shuffled_index
+                responses = [
+                    df["response0"].iloc[i].strip(),
+                    df["response1"].iloc[i].strip(),
+                ]
+                response0 = responses[shuffled_index]
+                response1 = responses[1 - shuffled_index]
+                task = asyncio.create_task(self.process_text(post, response0, response1, i,
limiter)) + tasks.append(task) + + results = await tqdm_asyncio.gather(*tasks) + + for _, (comparison, preferred, i, entire_conversation) in enumerate(results): + df.at[i, "explanation"] = comparison + df.at[i, "entire_conversation"] = entire_conversation + preferred_label = ( + "response0" + if (df.at[i, "shuffled_index"] == 0 and preferred == "A") + or (df.at[i, "shuffled_index"] == 1 and preferred == "B") + else "response1" + ) + df.at[i, "preferred"] = preferred_label + return df + + return asyncio.run(main(self.ljc, df)) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) + args, config, model_config = parser.parse_args_and_config() + + ################ + # Model & Tokenizer + ################ + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, + padding_side="left", + trust_remote_code=True, + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE + judge = LLMJudge(LLMJudgeConfig(n=-1, max_parallel_requests=20, llm_judge_template=TEMPLATE)) + ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) + policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) + ################ + # Dataset + ################ + raw_datasets = load_dataset(args.dataset_name) + if config.sanity_check: + for key in raw_datasets: + raw_datasets[key] = raw_datasets[key].select(range(1024)) + train_dataset = raw_datasets[args.dataset_train_split] + train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field) + + if args.dataset_test_split is not None: + eval_dataset = raw_datasets[args.dataset_test_split] + eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field) + else: + eval_dataset = None + ################ + # Training + ################ + + trainer = OnlineDPOTrainer( + config=config, + tokenizer=tokenizer, + policy=policy, + ref_policy=ref_policy, + judge=judge, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + trainer.train() + if not config.sanity_check: + trainer.save_model(config.output_dir) + if config.push_to_hub: + trainer.push_to_hub() + trainer.generate_completions() From 1f0f6b2db91cacaa2f5b8082cc39240a2f8d7203 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 17 Jul 2024 03:04:38 +0000 Subject: [PATCH 15/92] quick change --- trl/trainer/online_dpo_trainer.py | 68 ++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 315523696f7..aec6df1b41d 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -3,7 +3,7 @@ import os import time from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -52,8 +52,9 @@ def __init__( tokenizer: PreTrainedTokenizer, policy: nn.Module, ref_policy: nn.Module, - reward_model: nn.Module, train_dataset: Dataset, + reward_model: Optional[nn.Module] = None, + judge: Optional[Any] = None, data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, # less commonly used @@ -71,9 +72,10 @@ def __init__( self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding self.ref_policy = ref_policy - self.reward_model = reward_model 
self.train_dataset = train_dataset self.train_dataset_len = len(train_dataset) + self.reward_model = reward_model + self.judge = judge self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers @@ -117,7 +119,10 @@ def __init__( ######### # setup model, optimizer, and others ######### - for module in [policy, ref_policy, reward_model]: + setup_dropout_models = [policy, ref_policy] + if reward_model is not None: + setup_dropout_models.append(reward_model) + for module in setup_dropout_models: disable_dropout_in_model(module) if args.stop_token and args.stop_token == "eos": args.stop_token_id = tokenizer.eos_token_id @@ -175,16 +180,18 @@ def __init__( self.eval_dataloader = accelerator.prepare(self.eval_dataloader) if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) + if reward_model is not None: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) self.ref_policy = prepare_deepspeed( self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 ) self.deepspeed = self.model else: self.ref_policy = self.ref_policy.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) + if reward_model is not None: + self.reward_model = self.reward_model.to(self.accelerator.device) def get_train_dataloader(self) -> DataLoader: return self.dataloader @@ -302,23 +309,43 @@ def repeat_generator(): # Response Processing 2. run reward model on the truncated responses postprocessed_query_response = torch.cat((query, postprocessed_response), 1) sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - _, score, _ = get_reward( - reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) + if reward_model is not None: + _, score, _ = get_reward( + reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) responses.append(response) postprocessed_responses.append(postprocessed_response) logprobs.append(logprob) ref_logprobs.append(ref_logprob) sequence_lengths.append(sequence_length) - scores.append(score) + # scores.append(score) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) ref_logprobs = torch.cat(ref_logprobs, 0) sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - del (logprob, ref_logprob, score) + # scores = torch.cat(scores, 0) + + scores = torch.zeros(args.local_batch_size).to(device) + num_examples = postprocessed_responses.size(0) // 2 + if self.judge is not None: + df = pd.DataFrame( + { + "prompt": tokenizer.batch_decode(query[:num_examples], skip_special_tokens=True), + "response0": tokenizer.batch_decode( + postprocessed_responses[:num_examples], skip_special_tokens=True + ), + "response1": tokenizer.batch_decode( + postprocessed_responses[num_examples:], skip_special_tokens=True + ), + } + ) + judge_df = self.judge.judge(df) + scores[:num_examples] = torch.tensor(judge_df["preferred"] == "response0", dtype=torch.float) + scores[num_examples:] = torch.tensor(judge_df["preferred"] == "response1", dtype=torch.float) + + del (logprob, ref_logprob) torch.cuda.empty_cache() gc.collect() @@ -328,7 +355,9 @@ def repeat_generator(): contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1) if 
args.non_eos_penalty: scores = torch.where(contain_eos_token, scores, torch.full_like(scores, args.penalty_reward_value)) - # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + accelerator.print( + f"{scores.shape, scores.sum()=}, {(contain_eos_token.sum() / len(contain_eos_token))=}" + ) # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) @@ -342,7 +371,6 @@ def repeat_generator(): rlhf_reward = scores + non_score_reward # num_examples should be same as args.local_batch_size divided by 2 - num_examples = scores.size(0) // 2 first_half = scores[:num_examples] second_half = scores[num_examples:] @@ -548,10 +576,10 @@ def generate_completions(self, sampling: bool = False): table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + # _, score, _ = get_reward( + # self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + # ) + # table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) if sampling: break From ac11b75e7f4ff77eac5ed5b55d4826bbecb89156 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 18 Jul 2024 12:54:51 +0000 Subject: [PATCH 16/92] quick fix --- trl/trainer/online_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index aec6df1b41d..68e585daa41 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -332,7 +332,7 @@ def repeat_generator(): if self.judge is not None: df = pd.DataFrame( { - "prompt": tokenizer.batch_decode(query[:num_examples], skip_special_tokens=True), + "prompt": tokenizer.batch_decode(queries[:num_examples], skip_special_tokens=True), "response0": tokenizer.batch_decode( postprocessed_responses[:num_examples], skip_special_tokens=True ), From 2a7abcad38db3b4952d5c157219376a9ae0dee4e Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 18 Jul 2024 15:54:08 +0000 Subject: [PATCH 17/92] cache changes --- examples/scripts/online_dpo_llmjudge.py | 47 +++++++++++++++++++------ 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/examples/scripts/online_dpo_llmjudge.py b/examples/scripts/online_dpo_llmjudge.py index 3a601f1084c..d090d9dbb99 100644 --- a/examples/scripts/online_dpo_llmjudge.py +++ b/examples/scripts/online_dpo_llmjudge.py @@ -54,26 +54,47 @@ --beta 0.1 \ --response_length 53 \ --push_to_hub -python \ + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/online_dpo_llmjudge.py \ --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_llmjudge_tldr \ + --output_dir models/minimal/online_dpo_llmjudge_tldr_6.9b \ --per_device_train_batch_size 4 \ --gradient_accumulation_steps 16 \ - --local_rollout_forward_batch_size 32 \ + --local_rollout_forward_batch_size 8 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ --total_episodes 1000000 \ - --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --sft_model_path 
cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --model_name_or_path EleutherAI/pythia-6.9b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \ --save_strategy no \ --non_eos_penalty \ --stop_token eos \ --beta 0.1 \ --response_length 53 \ --push_to_hub + + +python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 +python examples/scripts/online_dpo_llmjudge.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --learning_rate 3e-6 \ + --output_dir models/minimal/online_dpo_llmjudge \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 64 \ + --total_episodes 30000 \ + --model_name_or_path EleutherAI/pythia-14m \ + --sft_model_path EleutherAI/pythia-14m \ + --reward_model_path EleutherAI/pythia-14m \ + --non_eos_penalty \ + --stop_token eos \ + --response_length 53 \ + --sanity_check \ + --base_url https://ip-26-0-166-125/v1 \ + --api_key token-abc123 \ + --model NousResearch/Meta-Llama-3-8B-Instruct """ @@ -134,6 +155,8 @@ class LLMJudgeConfig: model: str = "gpt-3.5-turbo-0125" max_parallel_requests: Optional[int] = None llm_judge_template: str = "" + base_url: Optional[str] = None + api_key: Optional[str] = None def __post_init__(self): if "gpt-3.5" in self.model: @@ -142,12 +165,14 @@ def __post_init__(self): self.max_parallel_requests = 11 elif "gpt-4" in self.model: self.max_parallel_requests = 13 + else: # assume self-hosted + self.max_parallel_requests = 11 class LLMJudge: def __init__(self, ljc: LLMJudgeConfig): self.ljc = ljc - self.async_client = AsyncOpenAI() + self.async_client = AsyncOpenAI(api_key=ljc.api_key, base_url=ljc.base_url) async def process_text(self, post: str, response0: str, response1: str, i: int, limiter=None): text = self.ljc.llm_judge_template.replace("{{post}}", post) @@ -223,8 +248,8 @@ async def main(ljc: LLMJudgeConfig, df: pd.DataFrame): if __name__ == "__main__": - parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) - args, config, model_config = parser.parse_args_and_config() + parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig, LLMJudgeConfig)) + args, config, model_config, judge_config = parser.parse_args_and_config() ################ # Model & Tokenizer @@ -237,7 +262,9 @@ async def main(ljc: LLMJudgeConfig, df: pd.DataFrame): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - judge = LLMJudge(LLMJudgeConfig(n=-1, max_parallel_requests=20, llm_judge_template=TEMPLATE)) + judge_config.n = -1 + judge_config.llm_judge_template = TEMPLATE + judge = LLMJudge(judge_config) ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) ################ From e74646fb1e7d9918eeb10a6668310b073a85e38f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:06:08 +0000 Subject: [PATCH 18/92] new semantics --- docs/source/online_dpo_trainer.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index b2e283b0d81..750c65b43ed 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -1,6 +1,6 @@ # Online DPO 
Trainer -TRL supports training LLMs with online DPO ([Guo et al., 2024](https://huggingface.co/papers/2402.04792)) with a reward model (RM). The idea of online DPO is to generate completions based on prompts and either have an RM or a LLM judge to rank the responses. Then the policy is updated with the ranked responses using the DPO loss. +TRL supports training LLMs with online DPO ([Guo et al., 2024](https://huggingface.co/papers/2402.04792)) with a reward model (RM). The idea of online DPO is to generate completions based on prompts and either have an RM or a LLM judge to rank the responses. Then the model is updated with the ranked responses using the DPO loss. While [Guo et al. (2024)](https://huggingface.co/papers/2402.04792) used a LLM judge, in this implementation we just used a RM. @@ -32,8 +32,8 @@ python examples/scripts/online_dpo.py \ The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35) * `eps`: Tracks the number of episodes per second. -* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy. -* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy. +* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model. +* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model. * `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. * `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. * `objective/scores`: The mean scores returned by the reward model / environment. 
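For reference, these logged quantities fit together as in the following minimal sketch, which follows the trainer code elsewhere in this series (`non_score_reward = (-kl_coef * kl).sum(1)` and `rlhf_reward = scores + non_score_reward`, so the non-score term is already negative). The tensor shapes and the `beta` value below are illustrative assumptions, not the actual implementation:

```python
# Minimal sketch of how the logged objective/* metrics relate (illustrative only).
import torch

beta = 0.05                        # KL penalty coefficient (`kl_coef`); value is an assumption
logprobs = torch.randn(4, 53)      # per-token logprobs of the current model (dummy values)
ref_logprobs = torch.randn(4, 53)  # per-token logprobs of the reference model (dummy values)
scores = torch.randn(4)            # scalar reward-model scores, one per sequence (dummy values)

kl = logprobs - ref_logprobs             # objective/kl is the mean of kl.sum(1)
non_score_reward = (-beta * kl).sum(1)   # objective/non_score_reward
rlhf_reward = scores + non_score_reward  # objective/rlhf_reward
entropy = (-logprobs).sum(1)             # objective/entropy is the mean of this
```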
@@ -181,7 +181,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --per_device_train_batch_size 16 \ --gradient_accumulation_steps 4 \ --local_rollout_forward_batch_size 32 \ - --num_ppo_epochs 1 \ + --num_epochs 1 \ --num_mini_batches 1 \ --total_episodes 1000000 \ --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ @@ -203,7 +203,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --per_device_train_batch_size 4 \ --gradient_accumulation_steps 16 \ --local_rollout_forward_batch_size 8 \ - --num_ppo_epochs 1 \ + --num_epochs 1 \ --num_epochs 1 \ --num_mini_batches 1 \ --total_episodes 1000000 \ From c93c81bc248bfe9deb22f8b960e15452d300f625 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:13:52 +0000 Subject: [PATCH 19/92] style and arg order change --- trl/trainer/online_dpo_config.py | 1 - trl/trainer/online_dpo_trainer.py | 25 +++++++++++++------------ trl/trainer/ppov2_trainer.py | 6 +----- trl/trainer/rloo_config.py | 1 + trl/trainer/utils.py | 3 +-- 5 files changed, 16 insertions(+), 20 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index df21eca9585..ec58990467b 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Literal - from trl.trainer.utils import OnPolicyConfig diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index f208cad3fc1..a5b10067402 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -5,7 +5,6 @@ from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union - import numpy as np import pandas as pd import torch @@ -22,6 +21,7 @@ Trainer, TrainerCallback, TrainerControl, + default_data_collator, ) from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK @@ -40,6 +40,7 @@ print_rich_table, truncate_response, ) +from .judges import BasePairwiseJudge from .online_dpo_config import OnlineDPOConfig @@ -49,18 +50,17 @@ class OnlineDPOTrainer(Trainer): def __init__( self, - config: OnlineDPOConfig, - tokenizer: PreTrainedTokenizer, model: nn.Module, + config: OnlineDPOConfig, ref_model: nn.Module, reward_model: Optional[nn.Module] = None, - judge: Optional[Any] = None, + judge: Optional[BasePairwiseJudge] = None, data_collator: Optional[DataCollatorWithPadding] = None, - train_dataset: Dataset, + train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - # less commonly used - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + tokenizer: Optional[PreTrainedTokenizer] = None, callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), ) -> None: self.args = config args = config @@ -77,7 +77,8 @@ def __init__( self.train_dataset = train_dataset self.train_dataset_len = len(train_dataset) - self.data_collator = data_collator + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + self.data_collator = data_collator if data_collator is not None else default_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers @@ -100,7 +101,9 @@ def __init__( 
args.local_mini_batch_size = exact_div( args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) - args.num_total_batches = math.ceil(args.total_episodes / args.batch_size) # we may train for more than `total_episodes` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` time_tensor = torch.tensor(int(time.time()), device=accelerator.device) time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" @@ -133,7 +136,6 @@ def __init__( num_training_steps=args.num_total_batches ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level - ######### ### trainer specifics ######### @@ -394,7 +396,6 @@ def repeat_generator(): # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch for epoch_idx in range(args.num_epochs): - b_inds = np.random.permutation(args.local_batch_size // args.num_generation_per_prompt) minibatch_idx = 0 for mini_batch_start in range( @@ -584,7 +585,7 @@ def generate_completions(self, sampling: bool = False): table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - if reward_model is not None: + if self.reward_model is not None: _, score, _ = get_reward( self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 732916dacf7..2ba0ed2856d 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -534,14 +534,10 @@ def repeat_generator(): if self.control.should_save: self._save_checkpoint(model, trial=None, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=metrics) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) torch.cuda.empty_cache() gc.collect() + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: self.generate_completions(sampling=True) diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index ae55cb6521f..e629d84afa0 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from ..trainer.utils import OnPolicyConfig diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 53cb5b8975f..3eaddf0ad1a 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os import dataclasses import json import random @@ -33,8 +32,8 @@ BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase, - TrainingArguments, TrainerState, + TrainingArguments, ) from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available From ff479e479f44ced001874b0f795a04938627051e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:14:43 +0000 Subject: [PATCH 20/92] rm duplicated num_epochs --- docs/source/online_dpo_trainer.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 750c65b43ed..e1ee9aed216 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -204,7 +204,6 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --gradient_accumulation_steps 16 \ --local_rollout_forward_batch_size 8 \ --num_epochs 1 \ - --num_epochs 1 \ --num_mini_batches 1 \ --total_episodes 1000000 \ --model_name_or_path EleutherAI/pythia-6.9b-deduped \ From f39c61a0d9cf6dd9d89bc03b769b915b9d4e1d65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:18:41 +0000 Subject: [PATCH 21/92] rm plot script --- docs/source/ppov2_trainer.md | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/docs/source/ppov2_trainer.md b/docs/source/ppov2_trainer.md index 39d804b3050..5cab83ad4c4 100644 --- a/docs/source/ppov2_trainer.md +++ b/docs/source/ppov2_trainer.md @@ -201,18 +201,6 @@ Model win rate: 33.00% $ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000 Model win rate: 64.70% ``` -import matplotlib.pyplot as plt - -ys = [34.4, 53.2, 52.8] -xs = ["SFT policy", "RLOO policy 1B", "PPO Policy 1B"] - -plt.bar(xs, ys) -plt.ylabel('Win rate against reference summaries') -plt.xlabel('Model Name') -plt.title('Win Rate Comparison') - -plt.show() -``` The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended. 
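For a sense of what such a judge-based win-rate evaluation does under the hood, here is a rough sketch of a single pairwise judgment. The prompt wording and helper below are illustrative assumptions, not the actual `judge_tldr.py` implementation; the judges in this series also expose a `shuffle_order` option to reduce position bias, which this sketch omits:

```python
# Rough sketch of one pairwise LLM-judge call used for win-rate evaluation.
# The prompt template is an assumption for illustration purposes.
from openai import OpenAI

client = OpenAI()

def prefers_model_summary(post: str, reference: str, candidate: str) -> bool:
    """Return True if the judge prefers the model summary over the reference one."""
    prompt = (
        f"Post:\n{post}\n\n"
        f"Summary A:\n{reference}\n\n"
        f"Summary B:\n{candidate}\n\n"
        "Which summary is better? Answer with a single letter, A or B."
    )
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1,
    )
    return completion.choices[0].message.content.strip().upper() == "B"

# The model win rate is the fraction of evaluation posts where this returns True.
```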
From e515d0ee929ef1cbffee05985c80125be3a4a56d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:19:14 +0000 Subject: [PATCH 22/92] num_epoch --- examples/scripts/online_dpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index 8937fe7fca5..72a7d3c0a65 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -36,7 +36,7 @@ --output_dir models/minimal/online_dpo \ --per_device_train_batch_size 16 \ --local_rollout_forward_batch_size 32 \ - --num_ppo_epochs 1 \ + --num_epochs 1 \ --num_mini_batches 1 \ --gradient_accumulation_steps 4 \ --total_episodes 1000000 \ From 0d8ae8c37380ddcb31f6b2c49c7d9a56ce23e2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:27:53 +0000 Subject: [PATCH 23/92] revert some changes --- trl/trainer/online_dpo_trainer.py | 2 - trl/trainer/online_trainer.py | 469 ------------------------------ trl/trainer/rloo_trainer.py | 1 + 3 files changed, 1 insertion(+), 471 deletions(-) delete mode 100644 trl/trainer/online_trainer.py diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index a5b10067402..318b7ead0da 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -336,8 +336,6 @@ def repeat_generator(): sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - scores = torch.zeros(args.local_batch_size).to(device) - num_examples = postprocessed_responses.size(0) // 2 if self.judge is not None: df = pd.DataFrame( { diff --git a/trl/trainer/online_trainer.py b/trl/trainer/online_trainer.py deleted file mode 100644 index 4ed603ccc9a..00000000000 --- a/trl/trainer/online_trainer.py +++ /dev/null @@ -1,469 +0,0 @@ -import gc -import os -import time -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from accelerate import Accelerator -from accelerate.utils import broadcast, gather_object -from datasets import Dataset -from torch.utils.data import DataLoader -from transformers import ( - DataCollatorWithPadding, - GenerationConfig, - PreTrainedTokenizer, - Trainer, - TrainerCallback, - TrainerControl, - TrainerState, -) -from transformers.integrations import get_reporting_integration_callbacks -from transformers.trainer_callback import CallbackHandler, DefaultFlowCallback - -from ..models.utils import unwrap_model_for_generation -from ..trainer.utils import ( - disable_dropout_in_model, - exact_div, - first_true_indices, - forward, - generate, - get_reward, - prepare_deepspeed, - print_rich_table, - truncate_response, -) -from .rloo_config import RLOOConfig - - -INVALID_LOGPROB = 1.0 - - -class RLOOTrainer(Trainer): - def __init__( - self, - config: RLOOConfig, - tokenizer: PreTrainedTokenizer, - policy: nn.Module, - ref_policy: nn.Module, - reward_model: nn.Module, - train_dataset: Dataset, - data_collator: Optional[DataCollatorWithPadding] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - # less commonly used - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - # compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - # model_init: Optional[Callable[[torch.nn.Module], None]] = None, - callbacks: Optional[List[TrainerCallback]] = None, - ) -> None: - 
self.args = config - args = config - self.tokenizer = tokenizer - self.policy = policy - - self.policy.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - - self.ref_policy = ref_policy - self.reward_model = reward_model - self.train_dataset = train_dataset - self.train_dataset_len = len(train_dataset) - self.data_collator = data_collator - self.eval_dataset = eval_dataset - self.optimizer, self.lr_scheduler = optimizers - self.callbacks = callbacks - - ######### - # calculate various batch sizes - ######### - if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - self.accelerator = accelerator - args.world_size = accelerator.num_processes - args.local_batch_size = ( - args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches - ) - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.batch_size = int(args.local_batch_size * args.world_size) - args.mini_batch_size = exact_div( - args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" - ) - args.local_mini_batch_size = exact_div( - args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" - ) - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.local_batch_size` - # `per_rank_minibatch_size` is our `args.local_mini_batch_size` - args.num_updates = args.total_episodes // args.batch_size - time_tensor = torch.tensor(int(time.time()), device=accelerator.device) - time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes - args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" - self.local_seed = args.seed + accelerator.process_index * 100003 # Prime - if args.num_sample_generations > 0: - self.sample_generations_freq = max(1, args.num_updates // args.num_sample_generations) - self.local_dataloader_batch_size = exact_div( - args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k" - ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times - - ######### - # setup model, optimizer, and others - ######### - for module in [policy, ref_policy, reward_model]: - disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = tokenizer.eos_token_id - self.model = policy - self.create_optimizer_and_scheduler(num_training_steps=args.num_updates) - - ######### - ### trainer specifics - ######### - self.state = TrainerState( - is_local_process_zero=self.is_local_process_zero(), - is_world_process_zero=self.is_world_process_zero(), - ) - DEFAULT_CALLBACKS = [DefaultFlowCallback] - default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - if self.callbacks is None: - self.callbacks = default_callbacks - self.callback_handler = CallbackHandler( - self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler - ) - self.control = TrainerControl() - self.is_deepspeed_enabled = getattr(self.accelerator.state, 
"deepspeed_plugin", None) is not None - self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None - # Create distant repo and output directory if needed - self.hub_model_id = None - if self.args.push_to_hub: - self.init_hf_repo() - if self.args.should_save: - os.makedirs(self.args.output_dir, exist_ok=True) - self.backup_model = None - - ######### - ### setup dataloader - ######### - self.dataloader = DataLoader( - self.train_dataset, - batch_size=self.local_dataloader_batch_size, - shuffle=True, - collate_fn=DataCollatorWithPadding(tokenizer), - drop_last=True, # needed; otherwise the last batch will be of ragged shape - ) - # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` - # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c - torch.manual_seed(args.seed) - self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) - torch.manual_seed(self.local_seed) # reset the local seed again - - self.eval_dataloader = DataLoader( - self.eval_dataset, - batch_size=args.per_device_eval_batch_size, - collate_fn=DataCollatorWithPadding(self.tokenizer), - drop_last=True, - ) # no need to shuffle eval dataset - self.eval_dataloader = accelerator.prepare(self.eval_dataloader) - - if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.bf16, args.fp16 - ) - self.ref_policy = prepare_deepspeed( - self.ref_policy, args.per_device_train_batch_size, args.bf16, args.fp16 - ) - self.deepspeed = self.model - else: - self.ref_policy = self.ref_policy.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) - - def get_train_dataloader(self) -> DataLoader: - return self.dataloader - - def get_eval_dataloader(self) -> DataLoader: - return self.eval_dataloader - - def train(self): - args = self.args - accelerator = self.accelerator - optimizer = self.optimizer - model = self.model - ref_policy = self.ref_policy - reward_model = self.reward_model - tokenizer = self.tokenizer - dataloader = self.dataloader - device = accelerator.device - - def repeat_generator(): - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - generation_config = GenerationConfig( - max_new_tokens=args.response_length, - min_new_tokens=args.response_length, - temperature=(args.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - accelerator.print("===training policy===") - global_step = 0 - start_time = time.time() - stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) - approxkl_stats = torch.zeros(stats_shape, device=device) - pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) - model.train() - for update in range(1, args.num_updates + 1): - global_step += 1 * args.batch_size - self.lr_scheduler.step() - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - queries = queries.repeat(args.rloo_k, 1) - context_length = queries.shape[1] - query_responses = [] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - sequence_lengths = [] - 
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response, logits = generate( - unwrapped_model, - query, - tokenizer.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - - # use the logits during generation directly, instead of using the following - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) - - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - _, score, _ = get_reward( - reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - - query_responses.append(query_response) - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - query_responses = torch.cat(query_responses, 0) - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - del (logprob, ref_logprob, score) - torch.cuda.empty_cache() - gc.collect() - - # Response Processing 3. filter response. Ensure that the sample contains stop_token_id - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1) - if args.non_eos_penalty: - scores = torch.where(contain_eos_token, scores, torch.full_like(scores, args.penalty_reward_value)) - # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - - # 4. 
compute rewards - kl = logprobs - ref_logprobs - non_score_reward = (-args.kl_coef * kl).sum(1) - rlhf_reward = scores + non_score_reward - - # vectorized RLOO advantages implementation - rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1) - baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1) - advantages = rlhf_reward - baseline - advantages = advantages.flatten() - torch.cuda.empty_cache() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - - output = forward(model, mb_query_responses, tokenizer.pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - new_logprobs = torch.masked_fill( - new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB - ) - new_ratio = (new_logprobs - mb_logprobs).exp() - new_logprobs = new_logprobs.sum(1) - mb_logprobs = mb_logprobs.sum(1) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = pg_loss_max.mean() - loss = pg_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = pg_clipfrac - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - # del everything and empty cache - # fmt: off - del ( - output, logits, new_all_logprobs, new_logprobs, - logprobs_diff, ratio, pg_losses, pg_losses2, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, - mb_advantage, mb_responses, mb_query_responses, mb_logprobs, - ) - # fmt: on - torch.cuda.empty_cache() - with torch.no_grad(): - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.mean() - eps = int(global_step / (time.time() - start_time)) - metrics = {} - 
metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() - metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() - metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() - metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() - metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() - metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() - metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = global_step - self.state.epoch = global_step / self.train_dataset_len # used by self.log - self.log(metrics) - del kl, mean_kl, mean_entropy, scores - torch.cuda.empty_cache() - gc.collect() - - if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: - self.generate_completions(sampling=True) - - def generate_completions(self, sampling: bool = False): - args = self.args - tokenizer = self.tokenizer - generation_config = GenerationConfig( - max_new_tokens=self.args.response_length, - temperature=(0.01 + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - table = defaultdict(list) - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - query_response, _ = generate( - unwrapped_model, - query, - tokenizer.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) - table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) - table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - - if sampling: - break - df = pd.DataFrame(table) - if self.accelerator.process_index == 0: - print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 17164cc8c2b..c4d6aa39fe9 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -437,6 +437,7 @@ def repeat_generator(): metrics["episode"] = self.state.episode self.state.epoch = 
self.state.episode / self.train_dataset_len # used by self.log self.log(metrics) + self.state.global_step += 1 del kl, mean_kl, mean_entropy, scores self.control = self.callback_handler.on_step_end(args, self.state, self.control) From 0641b55b2adcf1a088cb34b4493d9ddaef357be0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:34:49 +0000 Subject: [PATCH 24/92] revert changes --- docs/source/online_dpo_trainer.md | 2 ++ trl/trainer/online_dpo_trainer.py | 1 + trl/trainer/ppov2_trainer.py | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index e1ee9aed216..d608a980264 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -63,6 +63,7 @@ To help you understand what your model is doing, we periodically log some sample ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif) + In the logs the sampled generations look like ``` @@ -171,6 +172,7 @@ Many online implementation details are borrowed from the PPOv2Trainer, which is To validate the online DPO implementation works, we ran experiments on the 1B and 6.9B models. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + ``` # 1B Online DPO experiment accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 318b7ead0da..f022980bf33 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -337,6 +337,7 @@ def repeat_generator(): scores = torch.cat(scores, 0) if self.judge is not None: + num_examples = postprocessed_responses.size(0) // 2 df = pd.DataFrame( { "prompt": tokenizer.batch_decode(queries[:num_examples], skip_special_tokens=True), diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 2ba0ed2856d..0cabbd325ba 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -534,10 +534,10 @@ def repeat_generator(): if self.control.should_save: self._save_checkpoint(model, trial=None, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward torch.cuda.empty_cache() gc.collect() - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: self.generate_completions(sampling=True) From 25af762f7f4be0a77bc1aec97d4cc4e4d46bb479 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 09:50:08 +0000 Subject: [PATCH 25/92] revert whitespace --- examples/scripts/online_dpo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index 72a7d3c0a65..915ca46e0bc 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -28,7 +28,6 @@ --stop_token eos \ --response_length 53 \ --sanity_check - accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ examples/scripts/online_dpo.py \ --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ From 
29a1244d5e5a3edcafc95812dddbf29372e0a94d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 10:11:19 +0000 Subject: [PATCH 26/92] rm whitespace --- trl/trainer/ppov2_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 0cabbd325ba..c710e2929f5 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -535,7 +535,6 @@ def repeat_generator(): self._save_checkpoint(model, trial=None, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward - torch.cuda.empty_cache() gc.collect() From d858b2876836373e56a5c39794fc273cf53f713e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 10:12:06 +0000 Subject: [PATCH 27/92] revert change --- trl/trainer/rloo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index c4d6aa39fe9..2d964e559fe 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -436,8 +436,8 @@ def repeat_generator(): metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log - self.log(metrics) self.state.global_step += 1 + self.log(metrics) del kl, mean_kl, mean_entropy, scores self.control = self.callback_handler.on_step_end(args, self.state, self.control) From 4020f41f1aee180ad71978cb17eb17e1e5029c75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 1 Aug 2024 10:13:54 +0000 Subject: [PATCH 28/92] policy->model --- trl/trainer/online_dpo_trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index f022980bf33..3ba2f1fb47d 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -191,12 +191,10 @@ def __init__( self.reward_model = prepare_deepspeed( self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 ) - self.ref_policy = prepare_deepspeed( - self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 - ) + self.ref_model = prepare_deepspeed(self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16) self.deepspeed = self.model else: - self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.ref_model = self.ref_model.to(self.accelerator.device) if reward_model is not None: self.reward_model = self.reward_model.to(self.accelerator.device) From b1a264ad61d636acbd1edb2cea564ba6a2edf459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 2 Aug 2024 12:48:26 +0000 Subject: [PATCH 29/92] optional judge and reward model --- trl/trainer/online_dpo_config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index ec58990467b..0c23c9b313a 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -1,6 +1,6 @@ import os from dataclasses import dataclass -from typing import Literal +from typing import Literal, Optional from trl.trainer.utils import OnPolicyConfig @@ -9,8 +9,9 @@ class OnlineDPOConfig(OnPolicyConfig): exp_name: str = os.path.basename(__file__)[: -len(".py")] """the name of this experiment""" - 
reward_model_path: str = "EleutherAI/pythia-160m" + reward_model_path: Optional[str] = None """the path to the reward model""" + judge: Optional[str] = None num_epochs: int = 4 """the number of epochs to train""" From 79082f801fcbc0ffcbe089973fa0040f226f3ca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 2 Aug 2024 12:48:41 +0000 Subject: [PATCH 30/92] cleaning online dpo script --- examples/scripts/online_dpo.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index 915ca46e0bc..f46b6a05ef0 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -8,13 +8,14 @@ AutoTokenizer, ) -from trl import ModelConfig +from trl import HfPairwiseJudge, ModelConfig from trl.commands.cli_utils import TrlParser from trl.trainer import OnlineDPOConfig, OnlineDPOTrainer from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE """ +# Sanity check with minimal config and model python examples/scripts/online_dpo.py \ --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ --learning_rate 3e-6 \ @@ -28,25 +29,21 @@ --stop_token eos \ --response_length 53 \ --sanity_check + accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ examples/scripts/online_dpo.py \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo \ --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 4 \ --local_rollout_forward_batch_size 32 \ --num_epochs 1 \ - --num_mini_batches 1 \ - --gradient_accumulation_steps 4 \ --total_episodes 1000000 \ - --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ - --save_strategy no \ --non_eos_penalty \ - --stop_token eos \ - --beta 0.1 \ - --response_length 53 \ - --push_to_hub + --stop_token eos """ @@ -93,10 +90,20 @@ def tokenize(element): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1) + ref_model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) + if config.reward_model_path is not None: + reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1) + else: + reward_model = None + + if config.judge is not None: + judge = HfPairwiseJudge() + else: + judge = None + ################ # Dataset ################ @@ -117,13 +124,14 @@ def tokenize(element): ################ trainer = OnlineDPOTrainer( - config=config, - tokenizer=tokenizer, model=model, + config=config, ref_model=ref_model, reward_model=reward_model, + judge=judge, train_dataset=train_dataset, eval_dataset=eval_dataset, + tokenizer=tokenizer, ) trainer.train() if not config.sanity_check: From 9554c80ef3a28c92351929d3b098c126ed0d6d34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 2 Aug 2024 12:49:02 +0000 Subject: [PATCH 31/92] warning when both reward mdoel and judge provided --- trl/trainer/online_dpo_trainer.py | 8 ++++++++ 1 
file changed, 8 insertions(+) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 3ba2f1fb47d..f90f2cd9e01 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -2,6 +2,7 @@ import math import os import time +import warnings from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -72,8 +73,15 @@ def __init__( model.generation_config.pad_token_id = None self.ref_model = ref_model + self.reward_model = reward_model self.judge = judge + if self.reward_model is not None and self.judge is not None: + warnings.warn( + "Both `reward_model` and `judge` are provided. Please provide only one of them. " + "Ignoring `judge` and using `reward_model`." + ) + self.train_dataset = train_dataset self.train_dataset_len = len(train_dataset) From 2d0a8a11deb966b45fb6a6f8593add893789c2bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3%A9dec?= Date: Fri, 9 Aug 2024 14:43:18 +0000 Subject: [PATCH 32/92] return -1 when the judge fails --- trl/trainer/judges.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py index ad76f792f09..0924ce61015 100644 --- a/trl/trainer/judges.py +++ b/trl/trainer/judges.py @@ -110,6 +110,11 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order: Returns: List of idxs, where each idx is the rank of the best completion for the corresponding prompt. E.g., 1 means that the second completion (idx=1) is the best. + + Note: + If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed. + For instance, this could occur if the underlying language model returned an invalid answer. + In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling. """ raise NotImplementedError("Judge subclasses must implement the `judge` method.") @@ -201,8 +206,8 @@ def get_rank(prompt, candidates): if response in ["0", "1"]: return int(response) else: - logging.warning(f"Invalid response from the model: {response}, using random choice instead.") - return random.choice([0, 1]) + logging.debug(f"Invalid response from the judge model: '{response}'. Using random choice instead.") + return -1 # Call the completions concurrently with concurrent.futures.ThreadPoolExecutor() as executor: @@ -228,7 +233,6 @@ class OpenAIPairwiseJudge(BasePairwiseJudge): Note that the system prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token response. max_requests (`int`, *optional*): The maximum number of requests to make to the OpenAI API. Defaults to 1000. If set to `None`, there is no limit. - """ def __init__( @@ -268,8 +272,8 @@ def get_rank(prompt, candidates): if response in ["0", "1"]: return int(response) else: - logging.warning(f"Invalid response from the model: {response}, using random choice instead.") - return random.choice([0, 1]) + logging.debug(f"Invalid response from the judge model: '{response}'.
Using random choice instead.") + return -1 # Call the completions concurrently with concurrent.futures.ThreadPoolExecutor() as executor: From 9580a8dff2818659b9ed6d1434aefe3ee67ff4f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 9 Aug 2024 14:43:30 +0000 Subject: [PATCH 33/92] dataset num proc --- trl/trainer/online_dpo_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 78ca52b9c63..a79e76fdf77 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -22,3 +22,4 @@ class OnlineDPOConfig(OnPolicyConfig): """the type of loss to use for online DPO""" disable_dropout: bool = True """whether to disable dropout of the model during training""" + dataset_num_proc: Optional[int] = None From 14be7b7cb3e4479ee76f4d87320395573be1acbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 9 Aug 2024 14:47:59 +0000 Subject: [PATCH 34/92] add judges in online dpo; fix collate and process within the trainer --- examples/scripts/online_dpo.py | 45 ++++--------- trl/trainer/online_dpo_trainer.py | 101 +++++++++++++++++++++--------- 2 files changed, 81 insertions(+), 65 deletions(-) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index f46b6a05ef0..e0ad57194db 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -19,12 +19,12 @@ python examples/scripts/online_dpo.py \ --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 64 \ + --output_dir online_dpo \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 16 \ --total_episodes 30000 \ --model_name_or_path EleutherAI/pythia-14m \ - --reward_model_path EleutherAI/pythia-14m \ + --judge hf_pairwise \ --non_eos_penalty \ --stop_token eos \ --response_length 53 \ @@ -56,25 +56,6 @@ class ScriptArguments: max_length: int = 512 -def prepare_dataset(dataset, tokenizer, dataset_text_field): - """pre-tokenize the dataset before training; only collate during training""" - - def tokenize(element): - outputs = tokenizer( - element[dataset_text_field], - padding=False, - ) - return {"input_ids": outputs["input_ids"]} - - return dataset.map( - tokenize, - remove_columns=dataset.column_names, - batched=True, - num_proc=4, # multiprocessing.cpu_count(), - load_from_cache_file=False, - ) - - if __name__ == "__main__": parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) args, config, model_config = parser.parse_args_and_config() @@ -107,18 +88,14 @@ def tokenize(element): ################ # Dataset ################ - raw_datasets = load_dataset(args.dataset_name) + ds = load_dataset(args.dataset_name) if config.sanity_check: - for key in raw_datasets: - raw_datasets[key] = raw_datasets[key].select(range(1024)) - train_dataset = raw_datasets[args.dataset_train_split] - train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field) - - if args.dataset_test_split is not None: - eval_dataset = raw_datasets[args.dataset_test_split] - eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field) - else: - eval_dataset = None + for key in ds: + ds[key] = ds[key].select(range(1024)) + + train_dataset = ds[args.dataset_train_split] + eval_dataset = ds[args.dataset_test_split] + ################ # Training ################ diff --git 
a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index f5cb15c9524..e4fc702fb4c 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from accelerate import Accelerator +from accelerate import Accelerator, PartialState from accelerate.utils import broadcast, gather_object from datasets import Dataset from torch.utils.data import DataLoader @@ -28,6 +28,8 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, PrinterCallback +from trl.trainer.utils import DPODataCollatorWithPadding + from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( OnlineTrainerState, @@ -73,7 +75,6 @@ def __init__( model.generation_config.pad_token_id = None self.ref_model = ref_model - self.reward_model = reward_model self.judge = judge if self.reward_model is not None and self.judge is not None: @@ -81,13 +82,13 @@ def __init__( "Both `reward_model` and `judge` are provided. Please provide only one of them. " "Ignoring `judge` and using `reward_model`." ) + elif self.reward_model is None and self.judge is None: + raise ValueError("Either `reward_model` or `judge` must be provided.") - self.train_dataset = train_dataset self.train_dataset_len = len(train_dataset) - default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + default_collator = default_data_collator if tokenizer is None else DPODataCollatorWithPadding() self.data_collator = data_collator if data_collator is not None else default_collator - self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers self.num_generation_per_prompt = 2 @@ -135,7 +136,8 @@ def __init__( if args.disable_dropout: disable_dropout_in_model(model) self.ref_model.eval() - self.reward_model.eval() + if self.reward_model is not None: + self.reward_model.eval() if args.stop_token_id is None and args.stop_token and args.stop_token == "eos": args.stop_token_id = tokenizer.eos_token_id @@ -174,11 +176,19 @@ def __init__( ######### ### setup dataloader ######### + + # Compute that only on the main process for faster data processing.
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models) + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) self.dataloader = DataLoader( - self.train_dataset, + train_dataset, batch_size=self.local_dataloader_batch_size, shuffle=True, - collate_fn=DataCollatorWithPadding(tokenizer), + collate_fn=self.data_collator, drop_last=True, # needed; otherwise the last batch will be of ragged shape ) # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` @@ -188,15 +198,15 @@ def __init__( torch.manual_seed(self.local_seed) # reset the local seed again self.eval_dataloader = DataLoader( - self.eval_dataset, + eval_dataset, batch_size=args.per_device_eval_batch_size, - collate_fn=DataCollatorWithPadding(self.tokenizer), + collate_fn=self.data_collator, drop_last=True, ) # no need to shuffle eval dataset self.eval_dataloader = accelerator.prepare(self.eval_dataloader) if self.is_deepspeed_enabled: - if reward_model is not None: + if self.reward_model is not None: self.reward_model = prepare_deepspeed( self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 ) @@ -204,9 +214,26 @@ def __init__( self.deepspeed = self.model else: self.ref_model = self.ref_model.to(self.accelerator.device) - if reward_model is not None: + if self.reward_model is not None: self.reward_model = self.reward_model.to(self.accelerator.device) + def tokenize_row(self, feature) -> Dict: + """Tokenize a single row from a DPO specific dataset.""" + if not self.model.config.is_encoder_decoder: + batch = self.tokenizer(feature["prompt"], add_special_tokens=False) + # Add BOS token to head of prompt. 
Avoid adding if it's already there + prompt_len_input_ids = len(batch["input_ids"]) + if self.tokenizer.bos_token_id is not None: + if prompt_len_input_ids == 0 or self.tokenizer.bos_token_id != batch["input_ids"][0]: + batch["input_ids"] = [self.tokenizer.bos_token_id] + batch["input_ids"] + batch["attention_mask"] = [1] + batch["attention_mask"] + else: + batch = self.tokenizer( + feature["prompt"], truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + batch = {f"prompt_{key}": value for key, value in batch.items()} + return batch + def get_train_dataloader(self) -> DataLoader: return self.dataloader @@ -276,7 +303,7 @@ def repeat_generator(): self.lr_scheduler.step() data = next(iter_dataloader) with torch.no_grad(): - queries = data["input_ids"].to(device) + queries = data["prompt_input_ids"].to(device) queries = queries.repeat(self.num_generation_per_prompt, 1) context_length = queries.shape[1] responses = [] @@ -326,13 +353,16 @@ def repeat_generator(): _, score, _ = get_reward( reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) + else: + score = None responses.append(response) postprocessed_responses.append(postprocessed_response) logprobs.append(logprob) ref_logprobs.append(ref_logprob) sequence_lengths.append(sequence_length) - scores.append(score) + if score is not None: + scores.append(score) # stack all the tensors responses = torch.cat(responses, 0) @@ -340,24 +370,29 @@ def repeat_generator(): logprobs = torch.cat(logprobs, 0) ref_logprobs = torch.cat(ref_logprobs, 0) sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) + if score is not None: + scores = torch.cat(scores, 0) if self.judge is not None: num_examples = postprocessed_responses.size(0) // 2 - df = pd.DataFrame( - { - "prompt": tokenizer.batch_decode(queries[:num_examples], skip_special_tokens=True), - "response0": tokenizer.batch_decode( - postprocessed_responses[:num_examples], skip_special_tokens=True - ), - "response1": tokenizer.batch_decode( - postprocessed_responses[num_examples:], skip_special_tokens=True - ), - } + candidates_0 = tokenizer.batch_decode( postprocessed_responses[:num_examples], skip_special_tokens=True ) - judge_df = self.judge.judge(df) - scores[:num_examples] = torch.tensor(judge_df["preferred"] == "response0", dtype=torch.float) - scores[num_examples:] = torch.tensor(judge_df["preferred"] == "response1", dtype=torch.float) + candidates_1 = tokenizer.batch_decode( postprocessed_responses[num_examples:], skip_special_tokens=True ) + completions = [[c0, c1] for c0, c1 in zip(candidates_0, candidates_1)] + preferences = self.judge.judge( prompts=data["prompt"], completions=completions ) # preferences is a list of preferred indices + preferences = torch.tensor(preferences, dtype=torch.float32, device=device) + # Get the number of invalid answers by counting the number of -1 in preferences + invalid_rate = (preferences == -1).sum() / len(preferences) + # Replace invalid preferences with random preferences + preferences = torch.where( preferences == -1, torch.randint(0, 2, preferences.shape, device=device), preferences ) + scores = torch.cat((preferences, 1 - preferences)) del (logprob, ref_logprob, score) torch.cuda.empty_cache() @@ -535,6 +570,8 @@ def repeat_generator(): metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode + if self.judge is not None: +
metrics["judge/invalid_rate"] = invalid_rate.item() self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log self.state.global_step += 1 self.log(metrics) @@ -570,7 +607,7 @@ def generate_completions(self, sampling: bool = False): table = defaultdict(list) with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: for batch in self.eval_dataloader: - query = batch["input_ids"] + query = batch["prompt_input_ids"] with torch.no_grad(): context_length = query.shape[1] query_response, _ = batch_generation( @@ -586,8 +623,10 @@ def generate_completions(self, sampling: bool = False): postprocessed_response = truncate_response( args.stop_token_id, tokenizer.pad_token_id, response ) - table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) - table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) + table["Query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) + table["Model response"].extend( + gather_object(tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)) + ) postprocessed_query_response = torch.cat((query, postprocessed_response), 1) if self.reward_model is not None: From 1b7cdcf23f11052ad951f4ab3bf8703c4fa96208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 11 Aug 2024 15:47:41 +0000 Subject: [PATCH 35/92] lr_scheduler.step() after optimizer step --- trl/trainer/online_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index e4fc702fb4c..a2f67d7d4b6 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -300,7 +300,6 @@ def repeat_generator(): for update in range(1, args.num_total_batches + 1): self.state.episode += 1 * args.batch_size - self.lr_scheduler.step() data = next(iter_dataloader) with torch.no_grad(): queries = data["prompt_input_ids"].to(device) @@ -577,6 +576,7 @@ def repeat_generator(): self.log(metrics) del (kl, mean_kl, mean_entropy, scores, scores_margin) + self.lr_scheduler.step() self.control = self.callback_handler.on_step_end(args, self.state, self.control) if self.control.should_save: self._save_checkpoint(model, trial=None, metrics=metrics) From e58d4731233a2e24d00887a52a61fdec52b3ae65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 11 Aug 2024 15:48:04 +0000 Subject: [PATCH 36/92] update odpo test --- tests/test_online_dpo_trainer.py | 37 ++++++-------------------------- 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 31cd2316120..e69ddfe2f99 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -28,11 +28,6 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer.pad_token = self.tokenizer.eos_token - def _get_dummy_model_and_tokenizer(self): - # Return dummy model and tokenizer. This is a placeholder. 
- return self.model, self.tokenizer, self.reward_model - - def _init_dummy_dataset(self): # fmt: off dummy_dataset_dict = { "prompt": [ @@ -70,27 +65,9 @@ def _init_dummy_dataset(self): ], } # fmt: on - return Dataset.from_dict(dummy_dataset_dict) + self.dummy_dataset = Dataset.from_dict(dummy_dataset_dict) def test_online_dpo_trainer_training(self): - model, tokenizer, reward_model = self._get_dummy_model_and_tokenizer() - dummy_dataset = self._init_dummy_dataset() - - def tokenize(element): - outputs = tokenizer( - element["prompt"], - padding=False, - ) - return {"input_ids": outputs["input_ids"]} - - dummy_dataset = dummy_dataset.map( - tokenize, - remove_columns=dummy_dataset.column_names, - batched=True, - num_proc=4, # multiprocessing.cpu_count(), - load_from_cache_file=False, - ) - with tempfile.TemporaryDirectory() as tmp_dir: training_args = OnlineDPOConfig( output_dir=tmp_dir, @@ -104,13 +81,13 @@ def tokenize(element): ) trainer = OnlineDPOTrainer( - model=model, - ref_model=model, - reward_model=reward_model, + model=self.model, + ref_model=self.model, + reward_model=self.reward_model, config=training_args, - tokenizer=tokenizer, - train_dataset=dummy_dataset, - eval_dataset=dummy_dataset, + tokenizer=self.tokenizer, + train_dataset=self.dummy_dataset, + eval_dataset=self.dummy_dataset, ) trainer.train() From ced7c9827f12e3079bc1a7cecd6afca9b004e4b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 13 Aug 2024 17:37:13 +0000 Subject: [PATCH 37/92] reduce nestiness --- trl/trainer/online_dpo_trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index a2f67d7d4b6..46c4ff60219 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -623,10 +623,12 @@ def generate_completions(self, sampling: bool = False): postprocessed_response = truncate_response( args.stop_token_id, tokenizer.pad_token_id, response ) - table["Query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) - table["Model response"].extend( - gather_object(tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)) + query_text = tokenizer.batch_decode(query, skip_special_tokens=True) + postprocessed_response_text = tokenizer.batch_decode( + postprocessed_response, skip_special_tokens=True ) + table["Query"].extend(gather_object(query_text)) + table["Model response"].extend(gather_object(postprocessed_response_text)) postprocessed_query_response = torch.cat((query, postprocessed_response), 1) if self.reward_model is not None: From e14eb43be2986d9034a0ed319b7a933857a63337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 13 Aug 2024 19:41:10 +0000 Subject: [PATCH 38/92] allow pickle --- trl/trainer/online_dpo_trainer.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 46c4ff60219..d8a0eb297a2 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -180,10 +180,14 @@ def __init__( # Compute that only on the main process for faster data processing. 
# see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): - # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models) - train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + # tokenize the dataset + fn_kwargs = { + "is_encoder_decoder": self.model.config.is_encoder_decoder, + "tokenizer": tokenizer, + } + train_dataset = train_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) if eval_dataset is not None: - eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) self.dataloader = DataLoader( train_dataset, batch_size=self.local_dataloader_batch_size, @@ -217,20 +221,19 @@ def __init__( if self.reward_model is not None: self.reward_model = self.reward_model.to(self.accelerator.device) - def tokenize_row(self, feature) -> Dict: + @staticmethod + def tokenize_row(feature, is_encoder_decoder, tokenizer) -> Dict: """Tokenize a single row from a DPO specific dataset.""" - if not self.model.config.is_encoder_decoder: - batch = self.tokenizer(feature["prompt"], add_special_tokens=False) + if not is_encoder_decoder: + batch = tokenizer(feature["prompt"], add_special_tokens=False) # Add BOS token to head of prompt. Avoid adding if it's already there prompt_len_input_ids = len(batch["input_ids"]) - if self.tokenizer.bos_token_id is not None: - if prompt_len_input_ids == 0 or self.tokenizer.bos_token_id != batch["input_ids"][0]: - batch["input_ids"] = [self.tokenizer.bos_token_id] + batch["input_ids"] + if tokenizer.bos_token_id is not None: + if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]: + batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"] batch["attention_mask"] = [1] + batch["attention_mask"] else: - batch = self.tokenizer( - feature["prompt"], truncation=True, max_length=self.max_prompt_length, add_special_tokens=True - ) + batch = tokenizer(feature["prompt"], add_special_tokens=True) batch = {f"prompt_{key}": value for key, value in batch.items()} return batch @@ -566,6 +569,7 @@ def repeat_generator(): metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() metrics["logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats).mean().item() metrics["logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats).mean().item() + print(responses.tolist()) metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode From 629b6f17230ecca15bc2188b15e962f6dd1a34d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 13 Aug 2024 19:41:31 +0000 Subject: [PATCH 39/92] generation config typing --- trl/trainer/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index d7ef89e8619..0bbff567c8d 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -31,6 +31,7 @@ from transformers import ( BitsAndBytesConfig, DataCollatorForLanguageModeling, + GenerationConfig, PreTrainedTokenizerBase, TrainerState, TrainingArguments, @@ -1108,7 +1109,7 @@ def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Te def generate( - lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: dict + lm_backbone: 
torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: GenerationConfig ) -> Tuple[torch.Tensor, torch.Tensor]: """ Generates sequences from the language model backbone in a way that does not affect padding tokens. @@ -1120,8 +1121,8 @@ def generate( The tensor containing the input queries. pad_token_id (`int`): The token ID representing the pad token. - generation_config (`dict`): - The configuration dictionary for generation settings. + generation_config (`GenerationConfig`): + The configuration for the generation process. Returns: tuple: @@ -1152,7 +1153,7 @@ def batch_generation( queries: torch.Tensor, local_rollout_forward_batch_size: int, pad_token_id: int, - generation_config: dict, + generation_config: GenerationConfig, ): query_responses = [] logitss = [] From 4459efd8a8dbbce1d75695929a508be534c9f512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 13 Aug 2024 19:56:11 +0000 Subject: [PATCH 40/92] online dpo llm judge --- examples/scripts/online_dpo_llmjudge.py | 303 ------------------------ 1 file changed, 303 deletions(-) delete mode 100644 examples/scripts/online_dpo_llmjudge.py diff --git a/examples/scripts/online_dpo_llmjudge.py b/examples/scripts/online_dpo_llmjudge.py deleted file mode 100644 index d090d9dbb99..00000000000 --- a/examples/scripts/online_dpo_llmjudge.py +++ /dev/null @@ -1,303 +0,0 @@ -import asyncio -import random -import time -from dataclasses import dataclass -from typing import Optional - -import pandas as pd -from datasets import load_dataset -from openai import AsyncOpenAI -from tqdm.asyncio import tqdm_asyncio -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, -) - -from trl import ModelConfig -from trl.commands.cli_utils import TrlParser -from trl.trainer.online_dpo_trainer import OnlineDPOConfig, OnlineDPOTrainer -from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE - - -""" -python examples/scripts/online_dpo_llmjudge.py \ - --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_llmjudge \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 64 \ - --total_episodes 30000 \ - --model_name_or_path EleutherAI/pythia-14m \ - --sft_model_path EleutherAI/pythia-14m \ - --reward_model_path EleutherAI/pythia-14m \ - --non_eos_penalty \ - --stop_token eos \ - --response_length 53 \ - --sanity_check -accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ - examples/scripts/online_dpo_llmjudge.py \ - --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_llmjudge \ - --per_device_train_batch_size 16 \ - --local_rollout_forward_batch_size 32 \ - --num_ppo_epochs 1 \ - --num_mini_batches 1 \ - --gradient_accumulation_steps 4 \ - --total_episodes 1000000 \ - --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ - --save_strategy no \ - --non_eos_penalty \ - --stop_token eos \ - --beta 0.1 \ - --response_length 53 \ - --push_to_hub - -accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ - examples/scripts/online_dpo_llmjudge.py \ - --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_llmjudge_tldr_6.9b \ - 
--per_device_train_batch_size 4 \ - --gradient_accumulation_steps 16 \ - --local_rollout_forward_batch_size 8 \ - --num_ppo_epochs 1 \ - --num_mini_batches 1 \ - --total_episodes 1000000 \ - --model_name_or_path EleutherAI/pythia-6.9b-deduped \ - --sft_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \ - --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \ - --save_strategy no \ - --non_eos_penalty \ - --stop_token eos \ - --beta 0.1 \ - --response_length 53 \ - --push_to_hub - - -python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 -python examples/scripts/online_dpo_llmjudge.py \ - --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_llmjudge \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 64 \ - --total_episodes 30000 \ - --model_name_or_path EleutherAI/pythia-14m \ - --sft_model_path EleutherAI/pythia-14m \ - --reward_model_path EleutherAI/pythia-14m \ - --non_eos_penalty \ - --stop_token eos \ - --response_length 53 \ - --sanity_check \ - --base_url https://ip-26-0-166-125/v1 \ - --api_key token-abc123 \ - --model NousResearch/Meta-Llama-3-8B-Instruct -""" - - -@dataclass -class ScriptArguments: - dataset_name: str = None - dataset_text_field: str = "prompt" - dataset_train_split: str = "train" - dataset_test_split: Optional[str] = "validation" - max_length: int = 512 - - -def prepare_dataset(dataset, tokenizer, dataset_text_field): - """pre-tokenize the dataset before training; only collate during training""" - - def tokenize(element): - outputs = tokenizer( - element[dataset_text_field], - padding=False, - ) - return {"input_ids": outputs["input_ids"]} - - return dataset.map( - tokenize, - remove_columns=dataset.column_names, - batched=True, - num_proc=4, # multiprocessing.cpu_count(), - load_from_cache_file=False, - ) - - -TEMPLATE = r""" -Which of the following summaries does a better job of summarizing the most important points in the given forum post, without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence. - -### Post: -{{post}} - -### Summary A: -{{response0}} - -### Summary B: -{{response1}} - -### Instructions: -FIRST provide a one-sentence comparison of the two summaries, explaining which \ -you prefer and why. SECOND, on a new line, state only "A" or "B" to indicate your choice. Your response should use the format: -Comparison: -Preferred: <"A" or "B"> -""" - -# 1. extract the common text as the query (this way we do not require the user to provide the query) -# 2. 
support different kind of judges (let us say mainly LLM judges at the moment) - - -@dataclass -class LLMJudgeConfig: - n: int = 64 - model: str = "gpt-3.5-turbo-0125" - max_parallel_requests: Optional[int] = None - llm_judge_template: str = "" - base_url: Optional[str] = None - api_key: Optional[str] = None - - def __post_init__(self): - if "gpt-3.5" in self.model: - # gpt-3.5 generates so fast that it will exceeds the - # token limit per minute - self.max_parallel_requests = 11 - elif "gpt-4" in self.model: - self.max_parallel_requests = 13 - else: # assume self-hosted - self.max_parallel_requests = 11 - - -class LLMJudge: - def __init__(self, ljc: LLMJudgeConfig): - self.ljc = ljc - self.async_client = AsyncOpenAI(api_key=ljc.api_key, base_url=ljc.base_url) - - async def process_text(self, post: str, response0: str, response1: str, i: int, limiter=None): - text = self.ljc.llm_judge_template.replace("{{post}}", post) - text = text.replace("{{response0}}", response0) - text = text.replace("{{response1}}", response1) - - async with limiter: - response = None - while response is None: - try: - response = await self.async_client.chat.completions.create( - model=self.ljc.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": text}, - ], - ) - r = response.choices[0].message.content - except Exception as e: - print(f"error in {i}: {e}") - time.sleep(30) - continue - - try: - comparison = r.split("Comparison:")[1].split("Preferred:")[0].strip() - preferred = r.split("Preferred:")[1].strip() - return comparison, preferred, i, text + r - except Exception as e: - print(f"error in {i} {e}") - return "", random.choice(["A", "B"]), i, text + r - - def judge(self, df: pd.DataFrame): - async def main(ljc: LLMJudgeConfig, df: pd.DataFrame): - limiter = asyncio.Semaphore(ljc.max_parallel_requests) - """`df` should have columns: `prompt`, `response0`, `response1`""" - tasks = [] - df["explanation"] = [None for _ in range(len(df))] - df["preferred"] = [None for _ in range(len(df))] - df["shuffled_index"] = [None for _ in range(len(df))] - df["entire_conversation"] = [None for _ in range(len(df))] - r = range(min(ljc.n, len(df))) - if ljc.n == -1: - r = range(len(df)) - for i in r: - post = df["prompt"].iloc[i].strip() - # shuffled the index to avoid GPT4's preference bias in the content's order - shuffled_index = random.randint(0, 1) - df.at[i, "shuffled_index"] = shuffled_index - responses = [ - df["response0"].iloc[i].strip(), - df["response1"].iloc[i].strip(), - ] - response0 = responses[shuffled_index] - response1 = responses[1 - shuffled_index] - task = asyncio.create_task(self.process_text(post, response0, response1, i, limiter)) - tasks.append(task) - - results = await tqdm_asyncio.gather(*tasks) - - for _, (comparison, preferred, i, entire_conversation) in enumerate(results): - df.at[i, "explanation"] = comparison - df.at[i, "entire_conversation"] = entire_conversation - preferred_label = ( - "response0" - if (df.at[i, "shuffled_index"] == 0 and preferred == "A") - or (df.at[i, "shuffled_index"] == 1 and preferred == "B") - else "response1" - ) - df.at[i, "preferred"] = preferred_label - return df - - return asyncio.run(main(self.ljc, df)) - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig, LLMJudgeConfig)) - args, config, model_config, judge_config = parser.parse_args_and_config() - - ################ - # Model & Tokenizer - ################ - tokenizer = AutoTokenizer.from_pretrained( - 
model_config.model_name_or_path, - padding_side="left", - trust_remote_code=True, - ) - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - judge_config.n = -1 - judge_config.llm_judge_template = TEMPLATE - judge = LLMJudge(judge_config) - ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) - policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path) - ################ - # Dataset - ################ - raw_datasets = load_dataset(args.dataset_name) - if config.sanity_check: - for key in raw_datasets: - raw_datasets[key] = raw_datasets[key].select(range(1024)) - train_dataset = raw_datasets[args.dataset_train_split] - train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field) - - if args.dataset_test_split is not None: - eval_dataset = raw_datasets[args.dataset_test_split] - eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field) - else: - eval_dataset = None - ################ - # Training - ################ - - trainer = OnlineDPOTrainer( - config=config, - tokenizer=tokenizer, - policy=policy, - ref_policy=ref_policy, - judge=judge, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - ) - trainer.train() - if not config.sanity_check: - trainer.save_model(config.output_dir) - if config.push_to_hub: - trainer.push_to_hub() - trainer.generate_completions() From c7680c7371b45607f21697fe2d09709184e93dbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 13 Aug 2024 20:38:32 +0000 Subject: [PATCH 41/92] fix data collator pad token --- trl/trainer/online_dpo_trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index d8a0eb297a2..01a1b9fd55f 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -87,8 +87,15 @@ def __init__( self.train_dataset_len = len(train_dataset) - default_collator = default_data_collator if tokenizer is None else DPODataCollatorWithPadding() - self.data_collator = data_collator if data_collator is not None else default_collator + # Define the collator + if data_collator is None: + if tokenizer is not None: + self.data_collator = DPODataCollatorWithPadding(pad_token_id=tokenizer.pad_token_id) + else: # tokenizer is None + self.data_collator = default_data_collator + else: + self.data_collator = data_collator + self.optimizer, self.lr_scheduler = optimizers self.num_generation_per_prompt = 2 @@ -569,7 +576,6 @@ def repeat_generator(): metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() metrics["logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats).mean().item() metrics["logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats).mean().item() - print(responses.tolist()) metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode From d3d5175760f4cb35c5465065133e53025cc955f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 13 Aug 2024 21:11:24 +0000 Subject: [PATCH 42/92] add space --- trl/trainer/online_dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 01a1b9fd55f..618b49ecdb8 100644 --- a/trl/trainer/online_dpo_trainer.py +++ 
b/trl/trainer/online_dpo_trainer.py @@ -195,6 +195,7 @@ def __init__( train_dataset = train_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) if eval_dataset is not None: eval_dataset = eval_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) + self.dataloader = DataLoader( train_dataset, batch_size=self.local_dataloader_batch_size, From 94f142e9405ed2b9f49492acf36335bd3626c4b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 16:53:16 +0000 Subject: [PATCH 43/92] fix pref score --- trl/trainer/online_dpo_trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 618b49ecdb8..14c81f2ceeb 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -396,13 +396,18 @@ def repeat_generator(): prompts=data["prompt"], completions=completions ) # preferences is a list of prefered indexes preferences = torch.tensor(preferences, dtype=torch.float32, device=device) - # Get the number of invalid answers by counting the number of -1 in preferences + # Get the number of invalid answers by counting the number of -1 in preferences (just for logging) invalid_rate = (preferences == -1).sum() / len(preferences) # Replace invalid preferences with random preferences preferences = torch.where( preferences == -1, torch.randint(0, 2, preferences.shape, device=device), preferences ) - scores = torch.cat((preferences, 1 - preferences)) + # Convert preferences to scores + # The first half of the scores is the score of the first candidate. It's 1 when the first + # candidate is preferred, 0 otherwise. Since `preferences` is the index of the preferred candidate, + # the score of the first candidate is 1 - preferences. The score of the second candidate is the + # opposite of the score of the first candidate. + scores = torch.cat((1 - preferences, preferences)) del (logprob, ref_logprob, score) torch.cuda.empty_cache() From 5a0b4e97ca148dbba6fd62deb4c1ee11e410b329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 20:09:31 +0000 Subject: [PATCH 44/92] -1 for judges --- trl/trainer/judges.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py index 0924ce61015..5207fc637c7 100644 --- a/trl/trainer/judges.py +++ b/trl/trainer/judges.py @@ -206,7 +206,7 @@ def get_rank(prompt, candidates): if response in ["0", "1"]: return int(response) else: - logging.debug(f"Invalid response from the judge model: '{response}'. Using random choice instead.") + logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.") return -1 # Call the completions concurrently @@ -252,11 +252,11 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order: if self.max_requests is not None and self.num_requests >= self.max_requests: if not self._warned: # Print the warning only once logging.warning( - f"Reached the maximum number of requests ({self.max_requests}). From now on, using random choice instead. " + f"Reached the maximum number of requests ({self.max_requests}). From now on, returning -1 instead. " " To increase the limit, set `max_requests` to a higher value, or to `None` for no limit." 
) self._warned = True - return [random.choice([0, 1]) for _ in prompts] + return [-1] * len(prompts) # Shuffle the order of the completions to avoid positional bias if shuffle_order: @@ -272,7 +272,7 @@ def get_rank(prompt, candidates): if response in ["0", "1"]: return int(response) else: - logging.debug(f"Invalid response from the judge model: '{response}'. Using random choice instead.") + logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.") return -1 # Call the completions concurrently From ddac3b6815c7de6d7d62a7f600ebc48ac092a9b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 14 Aug 2024 20:09:48 +0000 Subject: [PATCH 45/92] self.model_wrapped = self.model --- trl/trainer/online_dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 14c81f2ceeb..69472a3f34f 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -256,6 +256,7 @@ def train(self): accelerator = self.accelerator optimizer = self.optimizer model = self.model + self.model_wrapped = self.model ref_model = self.ref_model reward_model = self.reward_model tokenizer = self.tokenizer From 3e2cfe56224ababd7f81c06a933f34df114bf9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 12:41:35 +0000 Subject: [PATCH 46/92] onlinedpo inherits from training arguments --- trl/trainer/online_dpo_config.py | 53 +++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index a79e76fdf77..8a6bbd559a6 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -1,14 +1,53 @@ -import os from dataclasses import dataclass from typing import Literal, Optional -from trl.trainer.utils import OnPolicyConfig +from transformers import TrainingArguments @dataclass -class OnlineDPOConfig(OnPolicyConfig): - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """the name of this experiment""" +class OnlineDPOConfig(TrainingArguments): + # batch size related config + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + num_sample_generations: int = 10 + """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" + + # other config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: int = -1 + """the reward value for responses that do not contain `stop_token_id`""" + non_eos_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + sft_model_path: str = "EleutherAI/pythia-160m" + """the path to the sft model""" + + # various batch sizes + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_total_batches: Optional[int] = None + """The number of total batches to train""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: 
Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + local_mini_batch_size: Optional[int] = None + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + reward_model_path: Optional[str] = None """the path to the reward model""" judge: Optional[str] = None @@ -23,3 +62,7 @@ class OnlineDPOConfig(OnPolicyConfig): disable_dropout: bool = True """whether to disable dropout of the model during training""" dataset_num_proc: Optional[int] = None + + + sanity_check: bool = False + """wether to run in debug mode""" \ No newline at end of file From 595c07e476121a10e9b64e5d0980cd0532aada61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 12:43:08 +0000 Subject: [PATCH 47/92] num_epoch -> num_steps_in_epochs --- trl/trainer/online_dpo_config.py | 5 ++--- trl/trainer/online_dpo_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 8a6bbd559a6..65d827eccd0 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -52,7 +52,7 @@ class OnlineDPOConfig(TrainingArguments): """the path to the reward model""" judge: Optional[str] = None - num_epochs: int = 4 + num_steps_in_epoch: int = 4 """the number of epochs to train""" beta: float = 0.05 @@ -63,6 +63,5 @@ class OnlineDPOConfig(TrainingArguments): """whether to disable dropout of the model during training""" dataset_num_proc: Optional[int] = None - sanity_check: bool = False - """wether to run in debug mode""" \ No newline at end of file + """wether to run in debug mode""" diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 69472a3f34f..1537b955e2e 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -279,7 +279,7 @@ def repeat_generator(): accelerator.print("===training policy===") start_time = time.time() - stats_shape = (args.num_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + stats_shape = (args.num_steps_in_epoch, args.num_mini_batches, args.gradient_accumulation_steps) loss_stats = torch.zeros(stats_shape, device=device) chosen_rewards_stats = torch.zeros(stats_shape, device=device) rejected_rewards_stats = torch.zeros(stats_shape, device=device) @@ -451,7 +451,7 @@ def repeat_generator(): torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for epoch_idx in range(args.num_epochs): + for epoch_idx in range(args.num_steps_in_epoch): b_inds = np.random.permutation(args.local_batch_size // self.num_generation_per_prompt) minibatch_idx = 0 for mini_batch_start in range( From 518c8963d3c6ed022bc7b9a0e26ea78602612c89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 12:45:19 +0000 Subject: [PATCH 48/92] update -> epoch --- trl/trainer/online_dpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 1537b955e2e..7a8986b10e0 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -310,7 +310,7 @@ def repeat_generator(): self.state.save_steps = args.save_steps self.control = 
self.callback_handler.on_train_begin(args, self.state, self.control) - for update in range(1, args.num_total_batches + 1): + for epoch in range(args.num_total_batches): self.state.episode += 1 * args.batch_size data = next(iter_dataloader) with torch.no_grad(): @@ -601,7 +601,7 @@ def repeat_generator(): torch.cuda.empty_cache() gc.collect() - if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + if args.num_sample_generations > 0 and epoch % self.sample_generations_freq == 0: self.generate_completions(sampling=True) # HF trainer specifics From 5c9fd95447db8e231654926d920b69e93fe0b04e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 12:57:05 +0000 Subject: [PATCH 49/92] epoch -> step; step_in_epoch -> ppo_epoch; rm run_name --- trl/trainer/online_dpo_trainer.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 7a8986b10e0..01785cbc3c8 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -12,7 +12,7 @@ import torch.nn as nn import torch.nn.functional as F from accelerate import Accelerator, PartialState -from accelerate.utils import broadcast, gather_object +from accelerate.utils import gather_object from datasets import Dataset from torch.utils.data import DataLoader from transformers import ( @@ -121,9 +121,6 @@ def __init__( args.num_total_batches = math.ceil( args.total_episodes / args.batch_size ) # we may train for more than `total_episodes` - time_tensor = torch.tensor(int(time.time()), device=accelerator.device) - time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes - args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" if args.run_name is None else args.run_name self.local_seed = args.seed + accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) @@ -279,7 +276,7 @@ def repeat_generator(): accelerator.print("===training policy===") start_time = time.time() - stats_shape = (args.num_steps_in_epoch, args.num_mini_batches, args.gradient_accumulation_steps) + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) loss_stats = torch.zeros(stats_shape, device=device) chosen_rewards_stats = torch.zeros(stats_shape, device=device) rejected_rewards_stats = torch.zeros(stats_shape, device=device) @@ -310,7 +307,7 @@ def repeat_generator(): self.state.save_steps = args.save_steps self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - for epoch in range(args.num_total_batches): + for step in range(args.num_total_batches): self.state.episode += 1 * args.batch_size data = next(iter_dataloader) with torch.no_grad(): @@ -451,7 +448,7 @@ def repeat_generator(): torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for epoch_idx in range(args.num_steps_in_epoch): + for epoch_idx in range(args.num_ppo_epochs): b_inds = np.random.permutation(args.local_batch_size // self.num_generation_per_prompt) minibatch_idx = 0 for mini_batch_start in range( @@ -601,7 +598,7 @@ def repeat_generator(): torch.cuda.empty_cache() gc.collect() - if args.num_sample_generations > 0 and epoch % self.sample_generations_freq == 0: + if args.num_sample_generations > 0 and step % self.sample_generations_freq == 0: 
self.generate_completions(sampling=True) # HF trainer specifics From 85c7bd5d51e9b9ba7636b8b85feb0afe15ceab4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 13:00:55 +0000 Subject: [PATCH 50/92] num_steps_in_epoch -> num_ppo_epochs --- trl/trainer/online_dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 65d827eccd0..f7d92e842fa 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -52,7 +52,7 @@ class OnlineDPOConfig(TrainingArguments): """the path to the reward model""" judge: Optional[str] = None - num_steps_in_epoch: int = 4 + num_ppo_epochs: int = 4 """the number of epochs to train""" beta: float = 0.05 From 2989a68074545c73f307a389acb472d2774e1a85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 13:01:22 +0000 Subject: [PATCH 51/92] epoch_idx -> ppo_epoch_idx --- trl/trainer/online_dpo_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 01785cbc3c8..6c75985e122 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -448,7 +448,7 @@ def repeat_generator(): torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for epoch_idx in range(args.num_ppo_epochs): + for ppo_epoch_idx in range(args.num_ppo_epochs): b_inds = np.random.permutation(args.local_batch_size // self.num_generation_per_prompt) minibatch_idx = 0 for mini_batch_start in range( @@ -531,18 +531,18 @@ def repeat_generator(): with torch.no_grad(): chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) - loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss chosen_rewards_stats[ - epoch_idx, minibatch_idx, gradient_accumulation_idx + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx ] = chosen_rewards.mean() rejected_rewards_stats[ - epoch_idx, minibatch_idx, gradient_accumulation_idx + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx ] = rejected_rewards.mean() chosen_logprobs_stats[ - epoch_idx, minibatch_idx, gradient_accumulation_idx + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx ] = chosen_logprobs_sum.mean() rejected_logprobs_stats[ - epoch_idx, minibatch_idx, gradient_accumulation_idx + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx ] = rejected_logprobs_sum.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 From 435bacd57a24431162c286385c2b8a4a3eb747c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 19 Aug 2024 13:23:12 +0000 Subject: [PATCH 52/92] make init consistent with dpo --- trl/trainer/online_dpo_trainer.py | 33 ++++++++++++++++--------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 6c75985e122..5dea0c3f028 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -16,9 +16,10 @@ from datasets import Dataset from torch.utils.data import DataLoader from transformers import ( - DataCollatorWithPadding, + DataCollator, GenerationConfig, - PreTrainedTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, 
Trainer, TrainerCallback, TrainerControl, @@ -28,10 +29,11 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, PrinterCallback -from trl.trainer.utils import DPODataCollatorWithPadding - from ..models.utils import unwrap_model_for_generation -from ..trainer.utils import ( +from .judges import BasePairwiseJudge +from .online_dpo_config import OnlineDPOConfig +from .utils import ( + DPODataCollatorWithPadding, OnlineTrainerState, batch_generation, disable_dropout_in_model, @@ -43,8 +45,6 @@ print_rich_table, truncate_response, ) -from .judges import BasePairwiseJudge -from .online_dpo_config import OnlineDPOConfig INVALID_LOGPROB = 1.0 @@ -53,20 +53,19 @@ class OnlineDPOTrainer(Trainer): def __init__( self, - model: nn.Module, - config: OnlineDPOConfig, - ref_model: nn.Module, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, reward_model: Optional[nn.Module] = None, judge: Optional[BasePairwiseJudge] = None, - data_collator: Optional[DataCollatorWithPadding] = None, + args: Optional[OnlineDPOConfig] = None, + data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - tokenizer: Optional[PreTrainedTokenizer] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), ) -> None: - self.args = config - args = config + self.args = args self.tokenizer = tokenizer # disable `pad_token_id` and `eos_token_id` because we just want to @@ -131,8 +130,8 @@ def __init__( ) # DPO logic: repeats the same prompt args.rloo_k times ### DPO stuff - self.beta = config.beta - self.loss_type = config.loss_type + self.beta = args.beta + self.loss_type = args.loss_type ######### # setup model, optimizer, and others @@ -558,6 +557,8 @@ def repeat_generator(): ) # fmt: on torch.cuda.empty_cache() + + # Log metrics with torch.no_grad(): mean_kl = kl.sum(1).mean() mean_entropy = (-logprobs).sum(1).mean() From 31d684dda5e98f5d737f7c1e00c009ce3001a3fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 20 Aug 2024 08:25:19 +0000 Subject: [PATCH 53/92] try another option --- examples/scripts/odpo.py | 38 ++++ examples/scripts/online_dpo.py | 15 +- trl/trainer/odpo.py | 287 +++++++++++++++++++++++++++++++ trl/trainer/online_dpo_config.py | 97 +++++------ 4 files changed, 384 insertions(+), 53 deletions(-) create mode 100644 examples/scripts/odpo.py create mode 100644 trl/trainer/odpo.py diff --git a/examples/scripts/odpo.py b/examples/scripts/odpo.py new file mode 100644 index 00000000000..567527bdd55 --- /dev/null +++ b/examples/scripts/odpo.py @@ -0,0 +1,38 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import ModelConfig +from trl.commands.cli_utils import TrlParser +from trl.trainer.odpo import OnlineDPOConfig, OnlineDPOTrainer + + +""" +python examples/scripts/online_dpo.py --output_dir online_dpo +""" + +if __name__ == "__main__": + parser = TrlParser((OnlineDPOConfig, ModelConfig)) + training_args, model_config = parser.parse_args_and_config() + + model = AutoModelForCausalLM.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr") + ref_model = 
AutoModelForCausalLM.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr") + reward_model = AutoModelForSequenceClassification.from_pretrained( + "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr" + ) + tokenizer = AutoTokenizer.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr", padding_side="left") + + dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style") + + # For simplicity, we only use the first 1024 tokens + for split in dataset: + dataset[split] = dataset[split].select(range(1024)) + + trainer = OnlineDPOTrainer( + model=model, + ref_model=ref_model, + reward_model=reward_model, + args=training_args, + train_dataset=dataset["train"], + tokenizer=tokenizer, + ) + trainer.train() diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index e74658b307b..01f84e0afc1 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Optional -from accelerate import PartialState from datasets import load_dataset from transformers import ( AutoModelForCausalLM, @@ -67,17 +66,23 @@ class ScriptArguments: tokenizer = AutoTokenizer.from_pretrained( model_config.model_name_or_path, padding_side="left", - trust_remote_code=True, + trust_remote_code=model_config.trust_remote_code, ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - ref_model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) - model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) + ref_model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + ) + model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + ) if config.reward_model_path is not None: - reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1) + reward_model = AutoModelForSequenceClassification.from_pretrained( + config.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code + ) else: reward_model = None diff --git a/trl/trainer/odpo.py b/trl/trainer/odpo.py new file mode 100644 index 00000000000..0e871b6f04d --- /dev/null +++ b/trl/trainer/odpo.py @@ -0,0 +1,287 @@ +import warnings +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import datasets +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +from accelerate import PartialState +from datasets import Dataset +from packaging import version +from torch.utils.data import DataLoader, IterableDataset +from transformers import DataCollator, GenerationConfig, PreTrainedTokenizerBase, Trainer, TrainerCallback +from transformers.modeling_utils import PreTrainedModel +from transformers.trainer_utils import EvalPrediction, seed_worker +from transformers.training_args import OptimizerNames +from transformers.utils import ( + is_apex_available, + is_sagemaker_mp_enabled, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_npu_available, + is_torch_xpu_available, + logging, +) + +from ..models.utils import unwrap_model_for_generation +from .judges import BasePairwiseJudge +from .online_dpo_config import OnlineDPOConfig +from .utils import ( + DPODataCollatorWithPadding, + 
batch_generation,
+ first_true_indices,
+ forward,
+ get_reward,
+ prepare_deepspeed,
+ truncate_response,
+)
+
+
+if is_apex_available():
+ from apex import amp
+
+
+if is_sagemaker_mp_enabled():
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+ from transformers.trainer_pt_utils import smp_forward_backward
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
+
+logger = logging.get_logger(__name__)
+
+
+class OnlineDPOTrainer(Trainer):
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module] = None,
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
+ reward_model: Optional[nn.Module] = None,
+ judge: Optional[BasePairwiseJudge] = None,
+ args: Optional[OnlineDPOConfig] = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ ) -> None:
+ self.ref_model = ref_model
+
+ if reward_model is not None and judge is not None:
+ warnings.warn(
+ "Both `reward_model` and `judge` are provided. Please provide only one of them. "
+ "Ignoring `judge` and using `reward_model`."
+ )
+ elif reward_model is None and judge is None:
+ raise ValueError("Either `reward_model` or `judge` must be provided.")
+
+ self.reward_model = reward_model
+ self.judge = judge
+
+ if args is None:
+ raise ValueError("`args` must be provided.")
+
+ # Check that the tokenizer is provided
+ if tokenizer is None:
+ raise ValueError("`tokenizer` must be provided.")
+
+ # We don't optimize the reward model nor the ref model, so we can set them to eval mode
+ self.ref_model.eval()
+ if self.reward_model is not None:
+ self.reward_model.eval()
+
+ # Define the collator if not provided
+ if data_collator is None:
+ data_collator = DPODataCollatorWithPadding(pad_token_id=tokenizer.pad_token_id)
+
+ # Compute that only on the main process for faster data processing.
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + # Tokenize the dataset + fn_kwargs = {"is_encoder_decoder": model.config.is_encoder_decoder, "tokenizer": tokenizer} + train_dataset = train_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator + if self.is_deepspeed_enabled: + if self.reward_model is not None: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_model = prepare_deepspeed(self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16) + else: + self.ref_model = self.ref_model.to(self.accelerator.device) + if self.reward_model is not None: + self.reward_model = self.reward_model.to(self.accelerator.device) + + @staticmethod + def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> Dict[str, Any]: + """Tokenize a single row from a DPO specific dataset.""" + if not is_encoder_decoder: + batch = tokenizer(feature["prompt"], add_special_tokens=False) + # Add BOS token to head of prompt. Avoid adding if it's already there + if tokenizer.bos_token_id is not None: + prompt_len_input_ids = len(batch["input_ids"]) + if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]: + batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"] + batch["attention_mask"] = [1] + batch["attention_mask"] + else: + batch = tokenizer(feature["prompt"], add_special_tokens=True) + batch = {f"prompt_{key}": value for key, value in batch.items()} + return batch + + # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns". 
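+ # (A reading of the default `transformers.Trainer` behavior, not part of the original comment:
+ # with `remove_unused_columns=True`, every dataset column that is not a parameter of
+ # `model.forward` would be dropped before collation -- including the "prompt_input_ids" and
+ # "prompt_attention_mask" columns produced by `tokenize_row` above, which `training_step`
+ # below reads directly. Overriding the dataloader keeps those columns intact. For example,
+ # a batch item is expected to look like, with invented token ids:
+ #   {"prompt": "...", "prompt_input_ids": [bos, 345, 673], "prompt_attention_mask": [1, 1, 1]}.)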
+ @wraps(Trainer.get_train_dataloader)
+ def get_train_dataloader(self) -> DataLoader:
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ dataloader_params = {
+ "batch_size": self._train_batch_size,
+ "collate_fn": data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "persistent_workers": self.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+ dataloader_params["sampler"] = self._get_train_sampler()
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
+ dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ model.train()
+
+ # Generate two completions
+ generation_config = GenerationConfig(
+ max_new_tokens=self.args.response_length,
+ min_new_tokens=self.args.response_length,
+ temperature=(self.args.temperature + 1e-7),
+ top_k=0.0,
+ top_p=1.0,
+ do_sample=True,
+ )
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+ query_responses, logits = batch_generation(
+ model=unwrapped_model,
+ queries=inputs["prompt_input_ids"],
+ local_rollout_forward_batch_size=8,
+ pad_token_id=self.tokenizer.pad_token_id,
+ generation_config=generation_config,
+ )
+ context_length = inputs["prompt_input_ids"].shape[1]
+ responses = query_responses[:, context_length:] # responses.shape[1] == self.args.response_length
+ # Turn logits into logprobs
+ all_logprobs = F.log_softmax(logits, dim=-1) # (batch_size, response_length, vocab_size)
+ # Take the response tokens logprob
+ logprobs = torch.take_along_dim(all_logprobs, responses.unsqueeze(-1), dim=2) # (batch_size, response_length)
+
+ # Same for the reference model
+ ref_output = forward(self.ref_model, query_responses, pad_token_id=self.tokenizer.pad_token_id)
+ # There is 1 offset, because the model predicts the next token
+ ref_logits = ref_output.logits[:, context_length - 1 : -1] / generation_config.temperature
+ # Turn logits into logprobs
+ ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
+ # Take the response tokens logprob
+ ref_logprobs = torch.take_along_dim(ref_all_logprob, responses.unsqueeze(-1), dim=2)
+
+ # Truncate response after the first occurrence of `stop_token_id`.
+ postprocessed_responses = truncate_response(
+ self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, responses
+ )
+ # Responses now look like: [123, 234, 345, EOS, PAD, PAD, ...]
+
+ # Run reward model on the truncated responses
+ postprocessed_query_responses = torch.hstack((inputs["prompt_input_ids"], postprocessed_responses))
+ _, scores, _ = get_reward(
+ self.reward_model, postprocessed_query_responses, self.tokenizer.pad_token_id, context_length
+ )
+
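+ # Shape sketch for the call above (batch size and lengths invented for illustration):
+ # with 4 prompts padded to length 6 and response_length 53, postprocessed_query_responses
+ # is (4, 59), and `scores` holds one scalar reward per sequence, read at the final
+ # non-padding token -- e.g. tensor([0.31, -1.20, 0.87, 0.05]).
+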
+ # Filter response. Ensure that the sample contains stop_token_id
+ # responses not passing that filter will receive a low (fixed) score
+ # only query humans on responses that pass that filter
+ contain_eos_token = torch.any(postprocessed_query_responses == self.tokenizer.eos_token_id, dim=-1)
+ if self.args.missing_eos_penalty is not None:
+ scores = torch.where(contain_eos_token, scores, self.args.missing_eos_penalty)
+
+ # be very careful with `padding_mask`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
+ # response_idxs looks like tensor([[ 0, 1, 2, 3], [ 0, 1, 2, 3]])
+ sequence_lengths = first_true_indices(postprocessed_responses == self.tokenizer.pad_token_id)
+ # The seq_len-th token is the EOS: [234, 345, EOS, PAD, PAD, ...] -> sequence_length = 2
+ padding_mask = response_idxs > (sequence_lengths.unsqueeze(1) - 1)
+ # padding mask looks like ...
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
+
+ inputs = self._prepare_inputs(inputs)
+ if is_sagemaker_mp_enabled():
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+ return loss_mb.reduce_mean().detach().to(self.args.device)
+
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs)
+
+ del inputs
+ if (
+ self.args.torch_empty_cache_steps is not None
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
+ ):
+ if is_torch_xpu_available():
+ torch.xpu.empty_cache()
+ elif is_torch_mlu_available():
+ torch.mlu.empty_cache()
+ elif is_torch_musa_available():
+ torch.musa.empty_cache()
+ elif is_torch_npu_available():
+ torch.npu.empty_cache()
+ elif is_torch_mps_available(min_version="2.0"):
+ torch.mps.empty_cache()
+ else:
+ torch.cuda.empty_cache()
+
+ kwargs = {}
+
+ # For LOMO optimizers you need to explicitly use the learning rate
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ kwargs["learning_rate"] = self._get_learning_rate()
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.use_apex:
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ self.accelerator.backward(loss, **kwargs)
+
+ return loss.detach() / self.args.gradient_accumulation_steps
diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py
index f7d92e842fa..d4f59dd7c18 100644
--- a/trl/trainer/online_dpo_config.py
+++ b/trl/trainer/online_dpo_config.py
@@ -6,62 +6,63 @@
 
 @dataclass
 class OnlineDPOConfig(TrainingArguments):
- # batch size related config
- num_mini_batches: int = 1
- """Number of minibatches to split a batch into"""
- total_episodes: Optional[int] = None
- """The total number of episodes in the dataset"""
- local_rollout_forward_batch_size: int = 64
- """per rank no grad forward pass in the rollout phase"""
- num_sample_generations: int = 10
- """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training"""
+ # # batch size related config
+ # num_mini_batches: int = 1
+ # """Number of minibatches to split a batch into"""
+ # total_episodes: Optional[int] = None
+ # """The total number of episodes in the dataset"""
+ # local_rollout_forward_batch_size: int = 64
+ # """per rank no grad forward pass in the rollout phase"""
+ # num_sample_generations: int = 10
+ # """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" - # other config + # # other config response_length: int = 53 - """the length of the response""" - stop_token: Optional[Literal["eos"]] = None - """the stop token""" - stop_token_id: Optional[int] = None - """the truncation token id""" + # """the length of the response""" + # stop_token: Optional[Literal["eos"]] = None + # """the stop token""" + # stop_token_id: Optional[int] = None + # """the truncation token id""" temperature: float = 0.7 - """the sampling temperature""" - penalty_reward_value: int = -1 - """the reward value for responses that do not contain `stop_token_id`""" - non_eos_penalty: bool = False - """whether to penalize responses that do not contain `stop_token_id`""" - sft_model_path: str = "EleutherAI/pythia-160m" - """the path to the sft model""" + missing_eos_penalty: Optional[float] = None + # """the sampling temperature""" + # penalty_reward_value: int = -1 + # """the reward value for responses that do not contain `stop_token_id`""" + # non_eos_penalty: bool = False + # """whether to penalize responses that do not contain `stop_token_id`""" + # sft_model_path: str = "EleutherAI/pythia-160m" + # """the path to the sft model""" - # various batch sizes - world_size: Optional[int] = None - """The number of processes (GPUs) to use""" - num_total_batches: Optional[int] = None - """The number of total batches to train""" - micro_batch_size: Optional[int] = None - """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" - local_batch_size: Optional[int] = None - """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" - batch_size: Optional[int] = None - """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" - local_mini_batch_size: Optional[int] = None - """the mini batch size per GPU""" - mini_batch_size: Optional[int] = None - """the mini batch size across GPUs""" + # # various batch sizes + # world_size: Optional[int] = None + # """The number of processes (GPUs) to use""" + # num_total_batches: Optional[int] = None + # """The number of total batches to train""" + # micro_batch_size: Optional[int] = None + # """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + # local_batch_size: Optional[int] = None + # """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + # batch_size: Optional[int] = None + # """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + # local_mini_batch_size: Optional[int] = None + # """the mini batch size per GPU""" + # mini_batch_size: Optional[int] = None + # """the mini batch size across GPUs""" - reward_model_path: Optional[str] = None - """the path to the reward model""" - judge: Optional[str] = None + # reward_model_path: Optional[str] = None + # """the path to the reward model""" + # judge: Optional[str] = None - num_ppo_epochs: int = 4 - """the number of epochs to train""" + # num_ppo_epochs: int = 4 + # """the number of epochs to train""" beta: float = 0.05 - """the entropy regularization coefficient of DPO""" + # """the entropy regularization coefficient of DPO""" loss_type: Literal["sigmoid", "ipo"] = "sigmoid" - """the type of loss to use for online DPO""" - disable_dropout: bool = True - """whether to disable dropout of the model during 
training""" + # """the type of loss to use for online DPO""" + # disable_dropout: bool = True + # """whether to disable dropout of the model during training""" dataset_num_proc: Optional[int] = None - sanity_check: bool = False - """wether to run in debug mode""" + # sanity_check: bool = False + # """wether to run in debug mode""" From 8fdbaa477c5be086e1ed8c681a03ae9fda3af5aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 20 Aug 2024 13:17:08 +0000 Subject: [PATCH 54/92] progress... --- trl/trainer/odpo.py | 146 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 140 insertions(+), 6 deletions(-) diff --git a/trl/trainer/odpo.py b/trl/trainer/odpo.py index 0e871b6f04d..446e07f49a3 100644 --- a/trl/trainer/odpo.py +++ b/trl/trainer/odpo.py @@ -202,8 +202,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, responses = query_responses[:, context_length:] # responses.shape[1] == self.args.response_length # Turn logits into logprobs all_logprobs = F.log_softmax(logits, dim=-1) # (batch_size, response_length, vocab_size) - # Take the response tokens logprob - logprobs = torch.take_along_dim(all_logprobs, responses.unsqueeze(-1), dim=2) # (batch_size, response_length) + # Take the response tokens logprob (batch_size, response_length) + logprobs = torch.take_along_dim(all_logprobs, responses.unsqueeze(-1), dim=2).squeeze(-1) # Same for the reference model ref_output = forward(self.ref_model, query_responses, pad_token_id=self.tokenizer.pad_token_id) @@ -212,7 +212,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, # Turn logits into logprobs ref_all_logprob = F.log_softmax(ref_logits, dim=-1) # Take the response tokens logprob - ref_logprobs = torch.take_along_dim(ref_all_logprob, responses.unsqueeze(-1), dim=2) + ref_logprobs = torch.take_along_dim(ref_all_logprob, responses.unsqueeze(-1), dim=2).squeeze(-1) # Truncate response after the first occurrence of `stop_token_id`. postprocessed_responses = truncate_response( @@ -239,9 +239,143 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, sequence_lengths = first_true_indices(postprocessed_responses == self.tokenizer.pad_token_id) # The seq_len-th token is the EOS: [234, 345, EOS, PAD, PAD, ...] -> sequence_length = 2 padding_mask = response_idxs > (sequence_lengths.unsqueeze(1) - 1) - # padding mask looks like ... - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + # With the above example, logprobs must look like [0.1, 0.2, 0.3, 1.0, 1.0, ...] 
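+        # A worked illustration with assumed numbers (not from a real run):
+        #   postprocessed_responses = [[234, 345, EOS, PAD, PAD]]
+        #   sequence_lengths        = [3]   (first_true_indices gives the position of the first PAD)
+        #   padding_mask            = [[False, False, False, True, True]]
+        # so the EOS logprob is kept, and only the trailing PAD positions receive the
+        # 1.0 sentinel below, a value no true log-probability can take.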
+ logprobs = logprobs.masked_fill(padding_mask, 1.0) + ref_logprobs = logprobs.masked_fill(padding_mask, 1.0) + + # Compute the rewards + kl = logprobs - ref_logprobs + non_score_reward = (-self.args.beta * kl).sum(1) + rlhf_reward = scores + non_score_reward + + # Split the scores in 2 + num_examples = scores.size(0) // 2 + first_half, second_half = scores.split(num_examples) + + # Get the indices of the chosen and rejected examples + num_examples_range = torch.arange(num_examples, device=scores.device) + mask = first_half >= second_half + chosen_indices = num_examples_range + (~mask * num_examples) + rejected_indices = num_examples_range + (mask * num_examples) + scores_margin = scores[chosen_indices] - scores[rejected_indices] + + + +#HERE + + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size // self.num_generation_per_prompt) + minibatch_idx = 0 + for mini_batch_start in range( + 0, + args.local_batch_size // self.num_generation_per_prompt, + args.local_mini_batch_size // self.num_generation_per_prompt, + ): + mini_batch_end = mini_batch_start + args.local_mini_batch_size // self.num_generation_per_prompt + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range( + 0, + args.local_mini_batch_size // self.num_generation_per_prompt, + args.per_device_train_batch_size, + ): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + ## chosen + chosen_mb_inds = chosen_indices[micro_batch_inds] + chosen_responses = responses[chosen_mb_inds] + + ## rejected + rejected_mb_inds = rejected_indices[micro_batch_inds] + rejected_responses = responses[rejected_mb_inds] + + concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) + concat_query_responses = query_responses[concat_mb_inds] + concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id) + num_examples = chosen_mb_inds.shape[0] + chosen_logits = concat_output.logits[:num_examples] + rejected_logits = concat_output.logits[num_examples:] + + # chosen + chosen_logits = chosen_logits[:, context_length - 1 : -1] + chosen_logits /= args.temperature + 1e-7 + chosen_all_logprobs = F.log_softmax(chosen_logits, dim=-1) + chosen_logprobs = torch.gather( + chosen_all_logprobs, 2, chosen_responses.unsqueeze(-1) + ).squeeze(-1) + chosen_logprobs = torch.masked_fill( + chosen_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB + ) + chosen_ref_logprobs = ref_logprobs[chosen_mb_inds] + chosen_logprobs_sum = (chosen_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + + # rejected + rejected_logits = rejected_logits[:, context_length - 1 : -1] + rejected_logits /= args.temperature + 1e-7 + rejected_all_logprobs = F.log_softmax(rejected_logits, dim=-1) + rejected_logprobs = torch.gather( + rejected_all_logprobs, 2, rejected_responses.unsqueeze(-1) + ).squeeze(-1) + rejected_logprobs = torch.masked_fill( + rejected_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB + ) + rejected_ref_logprobs = ref_logprobs[rejected_mb_inds] + rejected_logprobs_sum = (rejected_logprobs * ~padding_mask[rejected_mb_inds]).sum(1) + rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( + 1 + ) + + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + ref_logratios 
= chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + logits = pi_logratios - ref_logratios + + if self.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) + elif self.loss_type == "ipo": + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + + loss = losses.mean() + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss + chosen_rewards_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = chosen_rewards.mean() + rejected_rewards_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = rejected_rewards.mean() + chosen_logprobs_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = chosen_logprobs_sum.mean() + rejected_logprobs_stats[ + ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx + ] = rejected_logprobs_sum.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + self.state.global_step += 1 + # del everything and empty cache + # fmt: off + del ( + loss, logits, + concat_output, concat_query_responses, + chosen_logits, rejected_logits, + chosen_logprobs, rejected_logprobs, + chosen_responses, rejected_responses, + ) + # fmt: on + torch.cuda.empty_cache() + + + inputs = self._prepare_inputs(inputs) if is_sagemaker_mp_enabled(): From b369beb2e3579bffd24c26401878b4d62159d276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 20 Aug 2024 21:48:51 +0000 Subject: [PATCH 55/92] odpo --- examples/scripts/odpo.py | 7 +- trl/trainer/odpo.py | 295 +++++++++++-------------------- trl/trainer/online_dpo_config.py | 103 +++++------ 3 files changed, 159 insertions(+), 246 deletions(-) diff --git a/examples/scripts/odpo.py b/examples/scripts/odpo.py index 567527bdd55..548672ea7a5 100644 --- a/examples/scripts/odpo.py +++ b/examples/scripts/odpo.py @@ -3,7 +3,8 @@ from trl import ModelConfig from trl.commands.cli_utils import TrlParser -from trl.trainer.odpo import OnlineDPOConfig, OnlineDPOTrainer +from trl.trainer.odpo import ODPOTrainer +from trl.trainer.online_dpo_config import ODPOConfig """ @@ -11,7 +12,7 @@ """ if __name__ == "__main__": - parser = TrlParser((OnlineDPOConfig, ModelConfig)) + parser = TrlParser((ODPOConfig, ModelConfig)) training_args, model_config = parser.parse_args_and_config() model = AutoModelForCausalLM.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr") @@ -27,7 +28,7 @@ for split in dataset: dataset[split] = dataset[split].select(range(1024)) - trainer = OnlineDPOTrainer( + trainer = ODPOTrainer( model=model, ref_model=ref_model, reward_model=reward_model, diff --git a/trl/trainer/odpo.py b/trl/trainer/odpo.py index 446e07f49a3..529874c6b4d 100644 --- a/trl/trainer/odpo.py +++ b/trl/trainer/odpo.py @@ -49,14 +49,13 @@ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") - from transformers.trainer_pt_utils import smp_forward_backward else: IS_SAGEMAKER_MP_POST_1_10 = False logger = logging.get_logger(__name__) -class OnlineDPOTrainer(Trainer): +class ODPOTrainer(Trainer): def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, @@ -181,6 +180,8 @@ def get_train_dataloader(self) -> DataLoader: def training_step(self, model: nn.Module, inputs: Dict[str, 
Union[torch.Tensor, Any]]) -> torch.Tensor: model.train() + inputs = self._prepare_inputs(inputs) + # Generate two completions generation_config = GenerationConfig( max_new_tokens=self.args.response_length, @@ -190,202 +191,110 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, top_p=1.0, do_sample=True, ) - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - query_responses, logits = batch_generation( - model=unwrapped_model, - queries=inputs["prompt_input_ids"], - local_rollout_forward_batch_size=8, - pad_token_id=self.tokenizer.pad_token_id, - generation_config=generation_config, - ) - context_length = inputs["prompt_input_ids"].shape[1] - responses = query_responses[:, context_length:] # responses.shape[1] == self.args.response_length - # Turn logits into logprobs - all_logprobs = F.log_softmax(logits, dim=-1) # (batch_size, response_length, vocab_size) - # Take the response tokens logprob (batch_size, response_length) - logprobs = torch.take_along_dim(all_logprobs, responses.unsqueeze(-1), dim=2).squeeze(-1) - - # Same for the reference model - ref_output = forward(self.ref_model, query_responses, pad_token_id=self.tokenizer.pad_token_id) - # There is 1 offset, because the model predict the next token - ref_logits = ref_output.logits[:, context_length - 1 : -1] / generation_config.temperature - # Turn logits into logprobs - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - # Take the response tokens logprob - ref_logprobs = torch.take_along_dim(ref_all_logprob, responses.unsqueeze(-1), dim=2).squeeze(-1) - - # Truncate response after the first occurrence of `stop_token_id`. - postprocessed_responses = truncate_response( - self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, responses - ) - # Reponses now look like: [123, 234, 345, EOS, PAD, PAD, ...] - - # Run reward model on the truncated responses - postprocessed_query_responses = torch.hstack((inputs["prompt_input_ids"], postprocessed_responses)) - _, scores, _ = get_reward( - self.reward_model, postprocessed_query_responses, self.tokenizer.pad_token_id, context_length - ) - - # Filter response. Ensure that the sample contains stop_token_id - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_eos_token = torch.any(postprocessed_query_responses == self.tokenizer.eos_token_id, dim=-1) - if self.args.missing_eos_penalty is not None: - scores = torch.where(contain_eos_token, scores, self.args.missing_eos_penalty) - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - # response_idxs looks like tensor([[ 0, 1, 2, 3], [ 0, 1, 2, 3]]) - sequence_lengths = first_true_indices(postprocessed_responses == self.tokenizer.pad_token_id) - # The seq_len-th token is the EOS: [234, 345, EOS, PAD, PAD, ...] -> sequence_length = 2 - padding_mask = response_idxs > (sequence_lengths.unsqueeze(1) - 1) - # With the above example, logprobs must look like [0.1, 0.2, 0.3, 1.0, 1.0, ...] 
- logprobs = logprobs.masked_fill(padding_mask, 1.0) - ref_logprobs = logprobs.masked_fill(padding_mask, 1.0) - - # Compute the rewards - kl = logprobs - ref_logprobs - non_score_reward = (-self.args.beta * kl).sum(1) - rlhf_reward = scores + non_score_reward - - # Split the scores in 2 - num_examples = scores.size(0) // 2 - first_half, second_half = scores.split(num_examples) - - # Get the indices of the chosen and rejected examples - num_examples_range = torch.arange(num_examples, device=scores.device) - mask = first_half >= second_half - chosen_indices = num_examples_range + (~mask * num_examples) - rejected_indices = num_examples_range + (mask * num_examples) - scores_margin = scores[chosen_indices] - scores[rejected_indices] - - - -#HERE - - for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size // self.num_generation_per_prompt) - minibatch_idx = 0 - for mini_batch_start in range( - 0, - args.local_batch_size // self.num_generation_per_prompt, - args.local_mini_batch_size // self.num_generation_per_prompt, - ): - mini_batch_end = mini_batch_start + args.local_mini_batch_size // self.num_generation_per_prompt - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range( - 0, - args.local_mini_batch_size // self.num_generation_per_prompt, - args.per_device_train_batch_size, - ): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - ## chosen - chosen_mb_inds = chosen_indices[micro_batch_inds] - chosen_responses = responses[chosen_mb_inds] - - ## rejected - rejected_mb_inds = rejected_indices[micro_batch_inds] - rejected_responses = responses[rejected_mb_inds] - - concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) - concat_query_responses = query_responses[concat_mb_inds] - concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id) - num_examples = chosen_mb_inds.shape[0] - chosen_logits = concat_output.logits[:num_examples] - rejected_logits = concat_output.logits[num_examples:] - - # chosen - chosen_logits = chosen_logits[:, context_length - 1 : -1] - chosen_logits /= args.temperature + 1e-7 - chosen_all_logprobs = F.log_softmax(chosen_logits, dim=-1) - chosen_logprobs = torch.gather( - chosen_all_logprobs, 2, chosen_responses.unsqueeze(-1) - ).squeeze(-1) - chosen_logprobs = torch.masked_fill( - chosen_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB - ) - chosen_ref_logprobs = ref_logprobs[chosen_mb_inds] - chosen_logprobs_sum = (chosen_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) - chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) - - # rejected - rejected_logits = rejected_logits[:, context_length - 1 : -1] - rejected_logits /= args.temperature + 1e-7 - rejected_all_logprobs = F.log_softmax(rejected_logits, dim=-1) - rejected_logprobs = torch.gather( - rejected_all_logprobs, 2, rejected_responses.unsqueeze(-1) - ).squeeze(-1) - rejected_logprobs = torch.masked_fill( - rejected_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB - ) - rejected_ref_logprobs = ref_logprobs[rejected_mb_inds] - rejected_logprobs_sum = (rejected_logprobs * ~padding_mask[rejected_mb_inds]).sum(1) - rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( - 1 - ) - - pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum - ref_logratios 
= chosen_ref_logprobs_sum - rejected_ref_logprobs_sum - - logits = pi_logratios - ref_logratios - - if self.loss_type == "sigmoid": - losses = -F.logsigmoid(self.beta * logits) - elif self.loss_type == "ipo": - losses = (logits - 1 / (2 * self.beta)) ** 2 - else: - raise NotImplementedError(f"invalid loss type {self.loss_type}") - - loss = losses.mean() - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) - rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) - loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss - chosen_rewards_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = chosen_rewards.mean() - rejected_rewards_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = rejected_rewards.mean() - chosen_logprobs_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = chosen_logprobs_sum.mean() - rejected_logprobs_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = rejected_logprobs_sum.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - self.state.global_step += 1 - # del everything and empty cache - # fmt: off - del ( - loss, logits, - concat_output, concat_query_responses, - chosen_logits, rejected_logits, - chosen_logprobs, rejected_logprobs, - chosen_responses, rejected_responses, + repeated_prompts = inputs["prompt_input_ids"].repeat(2, 1) + with torch.no_grad(): + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + query_responses, logits = batch_generation( + model=unwrapped_model, + queries=repeated_prompts, # generate 2 completions per prompt + local_rollout_forward_batch_size=self.args.per_device_train_batch_size, + pad_token_id=self.tokenizer.pad_token_id, + generation_config=generation_config, ) - # fmt: on - torch.cuda.empty_cache() - - - + num_examples, context_length = inputs["prompt_input_ids"].shape + responses = query_responses[:, context_length:] # responses.shape[1] == self.args.response_length + # Turn logits into logprobs + all_logprobs = F.log_softmax(logits, dim=-1) # (batch_size, response_length, vocab_size) + # Take the response tokens logprob (batch_size, response_length) + logprobs = torch.take_along_dim(all_logprobs, responses.unsqueeze(-1), dim=2).squeeze(-1) + + # Same for the reference model + ref_output = forward(self.ref_model, query_responses, pad_token_id=self.tokenizer.pad_token_id) + # There is 1 offset, because the model predict the next token + ref_logits = ref_output.logits[:, context_length - 1 : -1] / generation_config.temperature + # Turn logits into logprobs + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + # Take the response tokens logprob + ref_logprobs = torch.take_along_dim(ref_all_logprob, responses.unsqueeze(-1), dim=2).squeeze(-1) + + # Truncate response after the first occurrence of `stop_token_id`. + processed_responses = truncate_response( + self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, responses + ) + # Reponses now look like: [123, 234, 345, EOS, PAD, PAD, ...] 
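+        # Pairing sketch (illustrative): `repeated_prompts` was built with repeat(2, 1), so
+        # completion i and completion i + num_examples answer the same prompt. Later, with
+        # num_examples = 2 and scores = [s0, s1, s2, s3], if s0 >= s2 and s3 > s1, then:
+        #   mask             = [True, False]
+        #   chosen_indices   = [0, 3]   # num_examples_range + (~mask * num_examples)
+        #   rejected_indices = [2, 1]   # num_examples_range + (mask * num_examples)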
- inputs = self._prepare_inputs(inputs) - if is_sagemaker_mp_enabled(): - loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) - return loss_mb.reduce_mean().detach().to(self.args.device) + # Run reward model on the truncated responses + processed_query_responses = torch.hstack((repeated_prompts, processed_responses)) + _, scores, _ = get_reward( + self.reward_model, processed_query_responses, self.tokenizer.pad_token_id, context_length + ) - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) + # Filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_eos_token = torch.any(processed_responses == self.tokenizer.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores = torch.where(contain_eos_token, scores, self.args.missing_eos_penalty) + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + # response_idxs looks like tensor([[ 0, 1, 2, 3], [ 0, 1, 2, 3]]) + sequence_lengths = first_true_indices(processed_responses == self.tokenizer.pad_token_id) + # The seq_len-th token is the EOS: [234, 345, EOS, PAD, PAD, ...] -> sequence_length = 2 + padding_mask = response_idxs > (sequence_lengths.unsqueeze(1) - 1) + # With the above example, logprobs must look like [0.1, 0.2, 0.3, 1.0, 1.0, ...] + logprobs = logprobs.masked_fill(padding_mask, 1.0) + ref_logprobs = ref_logprobs.masked_fill(padding_mask, 1.0) + + # Compute the rewards + kl = logprobs - ref_logprobs + non_score_reward = (-self.args.beta * kl).sum(1) + rlhf_reward = scores + non_score_reward + mean_entropy = (-logprobs).sum(1).mean() + + # Split the scores in 2 + first_half, second_half = scores.split(num_examples) + + # Get the indices of the chosen and rejected examples + num_examples_range = torch.arange(num_examples, device=scores.device) + mask = first_half >= second_half + chosen_indices = num_examples_range + (~mask * num_examples) + rejected_indices = num_examples_range + (mask * num_examples) + scores_margin = scores[chosen_indices] - scores[rejected_indices] + + cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) + + cr_responses = responses[cr_indices] + cr_query_responses = query_responses[cr_indices] + cr_output = forward(model, cr_query_responses, self.tokenizer.pad_token_id) + cr_logits = cr_output.logits + cr_logits = cr_logits[:, context_length - 1 : -1] / generation_config.temperature + cr_all_log_probs = F.log_softmax(cr_logits, dim=-1) + cr_logprobs = torch.take_along_dim(cr_all_log_probs, cr_responses.unsqueeze(-1), dim=2) + cr_logprobs = cr_logprobs.squeeze(-1) + cr_logprobs = cr_logprobs.masked_fill(padding_mask[cr_indices], 1.0) + + cr_logprobs_sum = (cr_logprobs * ~padding_mask[cr_indices]).sum(1) + cr_ref_logprobs_sum = (ref_logprobs[cr_indices] * ~padding_mask[cr_indices]).sum(1) + + # Split the chosen and rejected examples + chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, num_examples) + chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, num_examples) + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + logits = pi_logratios - ref_logratios + + if 
self.args.loss_type == "sigmoid": + losses = -F.logsigmoid(self.args.beta * logits) + elif self.args.loss_type == "ipo": + losses = (logits - 1 / (2 * self.args.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + loss = losses.mean() del inputs + if ( self.args.torch_empty_cache_steps is not None and self.state.global_step % self.args.torch_empty_cache_steps == 0 diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index d4f59dd7c18..a7d26724c26 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -6,63 +6,66 @@ @dataclass class OnlineDPOConfig(TrainingArguments): - # # batch size related config - # num_mini_batches: int = 1 - # """Number of minibatches to split a batch into""" - # total_episodes: Optional[int] = None - # """The total number of episodes in the dataset""" - # local_rollout_forward_batch_size: int = 64 - # """per rank no grad forward pass in the rollout phase""" - # num_sample_generations: int = 10 - # """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" + # batch size related config + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + num_sample_generations: int = 10 + """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" + """the length of the response""" + stop_token: Optional[Literal["eos"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + """the sampling temperature""" + penalty_reward_value: int = -1 + """the reward value for responses that do not contain `stop_token_id`""" + non_eos_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + sft_model_path: str = "EleutherAI/pythia-160m" + """the path to the sft model""" + # various batch sizes + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + num_total_batches: Optional[int] = None + """The number of total batches to train""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + local_mini_batch_size: Optional[int] = None + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + reward_model_path: Optional[str] = None + """the path to the reward model""" + judge: Optional[str] = None + """the type of loss to use for online DPO""" + disable_dropout: bool = True + """whether to disable dropout of the model during training""" + sanity_check: bool = False + """wether to run in debug mode""" # # other config response_length: int = 53 - # """the length of the response""" - # stop_token: Optional[Literal["eos"]] = None - # """the stop token""" - # stop_token_id: Optional[int] = None - # """the truncation token id""" temperature: float = 0.7 missing_eos_penalty: Optional[float] = None - # """the sampling temperature""" - # penalty_reward_value: 
int = -1 - # """the reward value for responses that do not contain `stop_token_id`""" - # non_eos_penalty: bool = False - # """whether to penalize responses that do not contain `stop_token_id`""" - # sft_model_path: str = "EleutherAI/pythia-160m" - # """the path to the sft model""" - - # # various batch sizes - # world_size: Optional[int] = None - # """The number of processes (GPUs) to use""" - # num_total_batches: Optional[int] = None - # """The number of total batches to train""" - # micro_batch_size: Optional[int] = None - # """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" - # local_batch_size: Optional[int] = None - # """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" - # batch_size: Optional[int] = None - # """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" - # local_mini_batch_size: Optional[int] = None - # """the mini batch size per GPU""" - # mini_batch_size: Optional[int] = None - # """the mini batch size across GPUs""" - - # reward_model_path: Optional[str] = None - # """the path to the reward model""" - # judge: Optional[str] = None + num_ppo_epochs: int = 4 + beta: float = 0.05 + loss_type: Literal["sigmoid", "ipo"] = "sigmoid" + dataset_num_proc: Optional[int] = None - # num_ppo_epochs: int = 4 - # """the number of epochs to train""" +@dataclass +class ODPOConfig(TrainingArguments): + response_length: int = 53 + temperature: float = 0.7 + missing_eos_penalty: Optional[float] = None beta: float = 0.05 - # """the entropy regularization coefficient of DPO""" loss_type: Literal["sigmoid", "ipo"] = "sigmoid" - # """the type of loss to use for online DPO""" - # disable_dropout: bool = True - # """whether to disable dropout of the model during training""" dataset_num_proc: Optional[int] = None - - # sanity_check: bool = False - # """wether to run in debug mode""" From d9c973616cff1ce83bf44650b94daba2914552b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 23 Aug 2024 22:26:57 +0000 Subject: [PATCH 56/92] current progress --- trl/trainer/odpo.py | 219 +++++++++++++++++-------------- trl/trainer/online_dpo_config.py | 2 +- 2 files changed, 118 insertions(+), 103 deletions(-) diff --git a/trl/trainer/odpo.py b/trl/trainer/odpo.py index 529874c6b4d..b0ff7e43aaa 100644 --- a/trl/trainer/odpo.py +++ b/trl/trainer/odpo.py @@ -20,7 +20,7 @@ is_sagemaker_mp_enabled, is_torch_mlu_available, is_torch_mps_available, - is_torch_musa_available, + # is_torch_musa_available, is_torch_npu_available, is_torch_xpu_available, logging, @@ -31,12 +31,9 @@ from .online_dpo_config import OnlineDPOConfig from .utils import ( DPODataCollatorWithPadding, - batch_generation, first_true_indices, - forward, get_reward, prepare_deepspeed, - truncate_response, ) @@ -55,6 +52,30 @@ logger = logging.get_logger(__name__) +def truncate_right(input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int): + """ + Truncates the input tensor from the right side after the first occurrence of the stop token. + + Args: + input_ids (`torch.Tensor`): + The tensor containing the responses to be truncated + stop_token_id (`int`): + The token ID representing the stop token where truncation occurs + pad_token_id (`int`): + The token ID representing the pad token used to fill the truncated responses + + Returns: + `torch.Tensor`: + The truncated responses tensor with pad tokens filled after the stop token. 
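+
+    Example (illustrative values; the stop token itself is kept, only tokens after it are padded):
+        >>> import torch
+        >>> input_ids = torch.tensor([[4, 5, 99, 6, 7]])
+        >>> truncate_right(input_ids, stop_token_id=99, pad_token_id=0)
+        (tensor([[ 4,  5, 99,  0,  0]]), tensor([[1, 1, 1, 0, 0]]))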
+    """
+    trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1)
+    new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]]
+    idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size)
+    output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id)
+    mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0)
+    return output_ids, mask
+
+
 class ODPOTrainer(Trainer):
     def __init__(
         self,
@@ -180,102 +201,94 @@ def get_train_dataloader(self) -> DataLoader:
     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
         model.train()
 
+        # Sample 2 completions per prompt of size `completion_length` from the model
         inputs = self._prepare_inputs(inputs)
-
-        # Generate two completions
         generation_config = GenerationConfig(
-            max_new_tokens=self.args.response_length,
-            min_new_tokens=self.args.response_length,
-            temperature=(self.args.temperature + 1e-7),
+            max_new_tokens=self.args.completion_length,
+            min_new_tokens=self.args.completion_length,
+            temperature=self.args.temperature,
             top_k=0.0,
             top_p=1.0,
             do_sample=True,
         )
-        repeated_prompts = inputs["prompt_input_ids"].repeat(2, 1)
-        with torch.no_grad():
-            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-                query_responses, logits = batch_generation(
-                    model=unwrapped_model,
-                    queries=repeated_prompts, # generate 2 completions per prompt
-                    local_rollout_forward_batch_size=self.args.per_device_train_batch_size,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    generation_config=generation_config,
+        num_examples, context_length = inputs["prompt_input_ids"].shape
+        prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
+        prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
+        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+            output = unwrapped_model.generate(
+                input_ids=prompt_ids,
+                attention_mask=prompt_mask,
+                generation_config=generation_config,
             )
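+        # Shape sketch (illustrative): `output` has shape
+        # (2 * num_examples, context_length + completion_length); rows i and
+        # i + num_examples are two independent samples for the same prompt.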
+ del inputs - # Run reward model on the truncated responses - processed_query_responses = torch.hstack((repeated_prompts, processed_responses)) + completion_ids = output[:, context_length:] # completions.shape[1] == self.args.completion_length + completion_ids, completion_mask = truncate_right( + completion_ids, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id + ) + prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1) + prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1) + + # Get the logprobs of the completions from the model + output = model(prompt_completion_ids, attention_mask=prompt_completion_mask) + # There is 1 offset, because the model predict the next token + logits = output.logits[:, context_length - 1 : -1] + # Turn logits into logprobs + all_logprobs = F.log_softmax(logits, dim=-1) + # Take the completion tokens logprob + logprobs = torch.take_along_dim(all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1) + del output, logits, all_logprobs # free memory + self.empty_cache() + + # Same for the reference model + with torch.no_grad(): + ref_output = self.ref_model(prompt_completion_ids, attention_mask=prompt_completion_mask) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) + ref_logprobs = torch.take_along_dim(ref_all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1) + del ref_output, ref_logits, ref_all_logprobs # free memory + self.empty_cache() + + # Get the reward from the reward model + with torch.no_grad(): _, scores, _ = get_reward( - self.reward_model, processed_query_responses, self.tokenizer.pad_token_id, context_length + self.reward_model, prompt_completion_ids, self.tokenizer.pad_token_id, context_length ) - # Filter response. Ensure that the sample contains stop_token_id - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_eos_token = torch.any(processed_responses == self.tokenizer.eos_token_id, dim=-1) - if self.args.missing_eos_penalty is not None: - scores = torch.where(contain_eos_token, scores, self.args.missing_eos_penalty) - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - # response_idxs looks like tensor([[ 0, 1, 2, 3], [ 0, 1, 2, 3]]) - sequence_lengths = first_true_indices(processed_responses == self.tokenizer.pad_token_id) - # The seq_len-th token is the EOS: [234, 345, EOS, PAD, PAD, ...] -> sequence_length = 2 - padding_mask = response_idxs > (sequence_lengths.unsqueeze(1) - 1) - # With the above example, logprobs must look like [0.1, 0.2, 0.3, 1.0, 1.0, ...] 
- logprobs = logprobs.masked_fill(padding_mask, 1.0) - ref_logprobs = ref_logprobs.masked_fill(padding_mask, 1.0) - - # Compute the rewards - kl = logprobs - ref_logprobs - non_score_reward = (-self.args.beta * kl).sum(1) - rlhf_reward = scores + non_score_reward - mean_entropy = (-logprobs).sum(1).mean() - - # Split the scores in 2 - first_half, second_half = scores.split(num_examples) - - # Get the indices of the chosen and rejected examples - num_examples_range = torch.arange(num_examples, device=scores.device) - mask = first_half >= second_half - chosen_indices = num_examples_range + (~mask * num_examples) - rejected_indices = num_examples_range + (mask * num_examples) - scores_margin = scores[chosen_indices] - scores[rejected_indices] - - cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) - - cr_responses = responses[cr_indices] - cr_query_responses = query_responses[cr_indices] - cr_output = forward(model, cr_query_responses, self.tokenizer.pad_token_id) - cr_logits = cr_output.logits - cr_logits = cr_logits[:, context_length - 1 : -1] / generation_config.temperature - cr_all_log_probs = F.log_softmax(cr_logits, dim=-1) - cr_logprobs = torch.take_along_dim(cr_all_log_probs, cr_responses.unsqueeze(-1), dim=2) - cr_logprobs = cr_logprobs.squeeze(-1) - cr_logprobs = cr_logprobs.masked_fill(padding_mask[cr_indices], 1.0) - - cr_logprobs_sum = (cr_logprobs * ~padding_mask[cr_indices]).sum(1) - cr_ref_logprobs_sum = (ref_logprobs[cr_indices] * ~padding_mask[cr_indices]).sum(1) + # Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a low (fixed) score + contain_eos_token = torch.any(completion_ids == self.tokenizer.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores = torch.where(contain_eos_token, scores, self.args.missing_eos_penalty) + + # Replace the logprobs of the padding tokens by 1.0 + padding_mask = ~completion_mask.bool() + logprobs = logprobs.masked_fill(padding_mask, 1.0) + ref_logprobs = ref_logprobs.masked_fill(padding_mask, 1.0) + + # Compute the rewards + kl = logprobs - ref_logprobs + non_score_reward = (-self.args.beta * kl).sum(1) + rlhf_reward = scores + non_score_reward + mean_entropy = (-logprobs).sum(1).mean() + + # Split the scores in 2 (the prompts of the first half are the same as the second half) + first_half, second_half = scores.split(num_examples) + + # Get the indices of the chosen and rejected examples + num_examples_range = torch.arange(num_examples, device=scores.device) + mask = first_half >= second_half + chosen_indices = num_examples_range + (~mask * num_examples) + rejected_indices = num_examples_range + (mask * num_examples) + scores_margin = scores[chosen_indices] - scores[rejected_indices] + + cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected + cr_logprobs = logprobs[cr_indices] + cr_ref_logprobs = ref_logprobs[cr_indices] + cr_padding_mask = padding_mask[cr_indices] + + cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1) + cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1) # Split the chosen and rejected examples chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, num_examples) @@ -293,24 +306,12 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, raise NotImplementedError(f"invalid loss type {self.loss_type}") loss = losses.mean() - del inputs if ( self.args.torch_empty_cache_steps is not None and self.state.global_step % 
self.args.torch_empty_cache_steps == 0 ): - if is_torch_xpu_available(): - torch.xpu.empty_cache() - elif is_torch_mlu_available(): - torch.mlu.empty_cache() - elif is_torch_musa_available(): - torch.musa.empty_cache() - elif is_torch_npu_available(): - torch.npu.empty_cache() - elif is_torch_mps_available(min_version="2.0"): - torch.mps.empty_cache() - else: - torch.cuda.empty_cache() + self.empty_cache() kwargs = {} @@ -328,3 +329,17 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, self.accelerator.backward(loss, **kwargs) return loss.detach() / self.args.gradient_accumulation_steps + + def empty_cache(self): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + # elif is_torch_musa_available(): + # torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(min_version="2.0"): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index a7d26724c26..4e59945c35e 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -63,7 +63,7 @@ class OnlineDPOConfig(TrainingArguments): @dataclass class ODPOConfig(TrainingArguments): - response_length: int = 53 + completion_length: int = 53 temperature: float = 0.7 missing_eos_penalty: Optional[float] = None beta: float = 0.05 From 48c449acee75c6af6594d6208dec4e68f701e4ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 17:55:51 +0000 Subject: [PATCH 57/92] log and other changes --- examples/scripts/odpo.py | 47 +++++++++++++--- trl/trainer/odpo.py | 112 ++++++++++++++++++++++++++++++++++----- 2 files changed, 139 insertions(+), 20 deletions(-) diff --git a/examples/scripts/odpo.py b/examples/scripts/odpo.py index 548672ea7a5..b59c385176d 100644 --- a/examples/scripts/odpo.py +++ b/examples/scripts/odpo.py @@ -1,12 +1,50 @@ from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig +from transformers.integrations import WandbCallback from trl import ModelConfig from trl.commands.cli_utils import TrlParser -from trl.trainer.odpo import ODPOTrainer +from trl.trainer.odpo import ODPOTrainer, truncate_right from trl.trainer.online_dpo_config import ODPOConfig +class LogCompletionsCallback(WandbCallback): + def __init__(self, prompts, freq=None): + super().__init__() + self.inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True) + self.table = [] + self._last_logged_step = -1 + self.freq = freq + + def on_step(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + if state.global_step == self._last_logged_step: + return + freq = self.freq or args.save_steps + if state.global_step % freq != 0: + return + model = kwargs["model"] + tokenizer = kwargs["tokenizer"] + model.eval() + generation_config = GenerationConfig( + max_new_tokens=args.completion_length, min_new_tokens=args.completion_length + ) + inputs = self.inputs.to(args.device) + _, context_length = inputs["input_ids"].shape + output = model.generate(**inputs, generation_config=generation_config) + completion_ids = output[:, context_length:] # completions.shape[1] == self.args.completion_length + completion_ids, _ = truncate_right(completion_ids, 
tokenizer.eos_token_id, tokenizer.pad_token_id) + prompts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) + completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions)) + self.table.extend(data) + table = self._wandb.Table(columns=["step", "prompt", "completion"], data=self.table) + self._wandb.log({"completions": table}) + self._last_logged_step = state.global_step + + """ python examples/scripts/online_dpo.py --output_dir online_dpo """ @@ -24,10 +62,6 @@ dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style") - # For simplicity, we only use the first 1024 tokens - for split in dataset: - dataset[split] = dataset[split].select(range(1024)) - trainer = ODPOTrainer( model=model, ref_model=ref_model, @@ -35,5 +69,6 @@ args=training_args, train_dataset=dataset["train"], tokenizer=tokenizer, + callbacks=[LogCompletionsCallback(dataset["test"]["prompt"][:8])], ) trainer.train() diff --git a/trl/trainer/odpo.py b/trl/trainer/odpo.py index b0ff7e43aaa..55fe12803f9 100644 --- a/trl/trainer/odpo.py +++ b/trl/trainer/odpo.py @@ -20,7 +20,6 @@ is_sagemaker_mp_enabled, is_torch_mlu_available, is_torch_mps_available, - # is_torch_musa_available, is_torch_npu_available, is_torch_xpu_available, logging, @@ -29,12 +28,7 @@ from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge from .online_dpo_config import OnlineDPOConfig -from .utils import ( - DPODataCollatorWithPadding, - first_true_indices, - get_reward, - prepare_deepspeed, -) +from .utils import DPODataCollatorWithPadding, first_true_indices, get_reward, prepare_deepspeed if is_apex_available(): @@ -132,6 +126,22 @@ def __init__( if eval_dataset is not None: eval_dataset = eval_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) + self.stats = { + "objective/kl": [], + "objective/entropy": [], + "objective/non_score_reward": [], + "objective/rlhf_reward": [], + "objective/scores": [], + "objective/scores_margin": [], + "rewards/chosen": [], + "rewards/rejected": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + "val/contain_eos_token": [], + } + super().__init__( model=model, args=args, @@ -198,6 +208,55 @@ def get_train_dataloader(self) -> DataLoader: return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns". 
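+    # (Removing "unused" columns would drop "prompt_input_ids" and "prompt_attention_mask",
+    # which the model's forward signature does not accept but which generation needs.)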
+ @wraps(Trainer.get_eval_dataloader) + def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) + data_collator = self.data_collator + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = eval_dataloader + else: + self._eval_dataloaders = {dataloader_key: eval_dataloader} + + return self.accelerator.prepare(eval_dataloader) + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: model.train() @@ -266,12 +325,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, logprobs = logprobs.masked_fill(padding_mask, 1.0) ref_logprobs = ref_logprobs.masked_fill(padding_mask, 1.0) - # Compute the rewards - kl = logprobs - ref_logprobs - non_score_reward = (-self.args.beta * kl).sum(1) - rlhf_reward = scores + non_score_reward - mean_entropy = (-logprobs).sum(1).mean() - # Split the scores in 2 (the prompts of the first half are the same as the second half) first_half, second_half = scores.split(num_examples) @@ -280,7 +333,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, mask = first_half >= second_half chosen_indices = num_examples_range + (~mask * num_examples) rejected_indices = num_examples_range + (mask * num_examples) - scores_margin = scores[chosen_indices] - scores[rejected_indices] cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected cr_logprobs = logprobs[cr_indices] @@ -307,6 +359,38 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, loss = losses.mean() + # Log everything + self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item()) + self.stats["logps/chosen"].append(self.accelerator.gather(chosen_logprobs_sum).mean().item()) + self.stats["logps/rejected"].append(self.accelerator.gather(rejected_logprobs_sum).mean().item()) + self.stats["objective/scores"].append(self.accelerator.gather(scores.mean()).mean().item()) + kl = logprobs - ref_logprobs + mean_kl = 
kl.sum(1).mean() + self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item()) + non_score_reward = (-self.args.beta * kl).sum(1) + mean_non_score_reward = non_score_reward.mean() + self.stats["objective/non_score_reward"].append(self.accelerator.gather(mean_non_score_reward).mean().item()) + rlhf_reward = scores + non_score_reward + self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item()) + mean_entropy = -logprobs.sum(1).mean() + self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item()) + scores_margin = scores[chosen_indices] - scores[rejected_indices] + self.stats["objective/scores_margin"].append(self.accelerator.gather(scores_margin.mean()).mean().item()) + chosen_rewards = self.args.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) + gathered_chosen_rewards = self.accelerator.gather(chosen_rewards) + self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item()) + rejected_rewards = self.args.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) + gathered_rejected_rewards = self.accelerator.gather(rejected_rewards) + self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item()) + margin = gathered_chosen_rewards - gathered_rejected_rewards + self.stats["rewards/margins"].append(margin.mean().item()) + accuracy = margin > 0 + self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) + + if self.state.global_step % self.args.logging_steps == 0: + self.log({key: sum(val) / len(val) for key, val in self.stats.items()}) + self.stats = {key: [] for key in self.stats} # reset stats + if ( self.args.torch_empty_cache_steps is not None and self.state.global_step % self.args.torch_empty_cache_steps == 0 From 238ac5a77074cc065cfe3fa7ad28ae210f76e8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:49:31 +0000 Subject: [PATCH 58/92] rename for legacy --- trl/trainer/{online_dpo_trainer.py => online_dpo_trainer_.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename trl/trainer/{online_dpo_trainer.py => online_dpo_trainer_.py} (100%) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer_.py similarity index 100% rename from trl/trainer/online_dpo_trainer.py rename to trl/trainer/online_dpo_trainer_.py From 3a37e3c26d205e5d44214a0570057b0671f3ede4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:55:09 +0000 Subject: [PATCH 59/92] rename for legacy --- examples/scripts/{online_dpo.py => online_dpo_.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/scripts/{online_dpo.py => online_dpo_.py} (100%) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo_.py similarity index 100% rename from examples/scripts/online_dpo.py rename to examples/scripts/online_dpo_.py From 4d73ee3f57de35b3c82fdfe4a2fbda99ae85feba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:55:52 +0000 Subject: [PATCH 60/92] rename and move truncate --- .../{odpo.py => online_dpo_trainer.py} | 29 ++----------------- 1 file changed, 3 insertions(+), 26 deletions(-) rename trl/trainer/{odpo.py => online_dpo_trainer.py} (93%) diff --git a/trl/trainer/odpo.py b/trl/trainer/online_dpo_trainer.py similarity index 93% rename from trl/trainer/odpo.py rename to trl/trainer/online_dpo_trainer.py index 55fe12803f9..c02570ef563 100644 --- a/trl/trainer/odpo.py +++ 
b/trl/trainer/online_dpo_trainer.py @@ -28,7 +28,7 @@ from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge from .online_dpo_config import OnlineDPOConfig -from .utils import DPODataCollatorWithPadding, first_true_indices, get_reward, prepare_deepspeed +from .utils import DPODataCollatorWithPadding, truncate_right, get_reward, prepare_deepspeed if is_apex_available(): @@ -46,30 +46,6 @@ logger = logging.get_logger(__name__) -def truncate_right(input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int): - """ - Truncates the input tensor from the right side after the first occurrence of the stop token. - - Args: - input_ids (`torch.Tensor`): - The tensor containing the responses to be truncated - stop_token_id (`int`): - The token ID representing the stop token where truncation occurs - pad_token_id (`int`): - The token ID representing the pad token used to fill the truncated responses - - Returns: - `torch.Tensor`: - The truncated responses tensor with pad tokens filled after the stop token. - """ - trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1) - new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]] - idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size) - output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) - mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) - return output_ids, mask - - class ODPOTrainer(Trainer): def __init__( self, @@ -334,6 +310,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, chosen_indices = num_examples_range + (~mask * num_examples) rejected_indices = num_examples_range + (mask * num_examples) + # Build tensor so that the first half is the chosen examples and the second half the rejected examples cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected cr_logprobs = logprobs[cr_indices] cr_ref_logprobs = ref_logprobs[cr_indices] @@ -387,7 +364,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, accuracy = margin > 0 self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) - if self.state.global_step % self.args.logging_steps == 0: + if self.state.global_step % self.state.logging_steps == 0: self.log({key: sum(val) / len(val) for key, val in self.stats.items()}) self.stats = {key: [] for key in self.stats} # reset stats From ba23435e3905e58579db72f5cdac12a03b67c80b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:56:15 +0000 Subject: [PATCH 61/92] rename --- examples/scripts/{odpo.py => online_dpo.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename examples/scripts/{odpo.py => online_dpo.py} (96%) diff --git a/examples/scripts/odpo.py b/examples/scripts/online_dpo.py similarity index 96% rename from examples/scripts/odpo.py rename to examples/scripts/online_dpo.py index b59c385176d..9073d3f011c 100644 --- a/examples/scripts/odpo.py +++ b/examples/scripts/online_dpo.py @@ -16,12 +16,12 @@ def __init__(self, prompts, freq=None): self._last_logged_step = -1 self.freq = freq - def on_step(self, args, state, control, **kwargs): + def on_step_end(self, args, state, control, **kwargs): if not state.is_world_process_zero: return if state.global_step == self._last_logged_step: return - freq = self.freq or args.save_steps + freq = self.freq or state.save_steps if state.global_step % freq != 0: return model = kwargs["model"] From 
2db55aaa131b27c2301454ce89b168a71ddb2b3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:56:27 +0000 Subject: [PATCH 62/92] new config --- trl/trainer/online_dpo_config.py | 77 ++++++++++---------------------- 1 file changed, 23 insertions(+), 54 deletions(-) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 4e59945c35e..e141be6d0b5 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -6,63 +6,32 @@ @dataclass class OnlineDPOConfig(TrainingArguments): - # batch size related config - num_mini_batches: int = 1 - """Number of minibatches to split a batch into""" - total_episodes: Optional[int] = None - """The total number of episodes in the dataset""" - local_rollout_forward_batch_size: int = 64 - """per rank no grad forward pass in the rollout phase""" - num_sample_generations: int = 10 - """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training""" - """the length of the response""" - stop_token: Optional[Literal["eos"]] = None - """the stop token""" - stop_token_id: Optional[int] = None - """the truncation token id""" - """the sampling temperature""" - penalty_reward_value: int = -1 - """the reward value for responses that do not contain `stop_token_id`""" - non_eos_penalty: bool = False - """whether to penalize responses that do not contain `stop_token_id`""" - sft_model_path: str = "EleutherAI/pythia-160m" - """the path to the sft model""" - # various batch sizes - world_size: Optional[int] = None - """The number of processes (GPUs) to use""" - num_total_batches: Optional[int] = None - """The number of total batches to train""" - micro_batch_size: Optional[int] = None - """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" - local_batch_size: Optional[int] = None - """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" - batch_size: Optional[int] = None - """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" - local_mini_batch_size: Optional[int] = None - """the mini batch size per GPU""" - mini_batch_size: Optional[int] = None - """the mini batch size across GPUs""" - reward_model_path: Optional[str] = None - """the path to the reward model""" - judge: Optional[str] = None - """the type of loss to use for online DPO""" - disable_dropout: bool = True - """whether to disable dropout of the model during training""" - sanity_check: bool = False - """wether to run in debug mode""" + r""" + Configuration class for the [`OnlineDPOTrainer`]. - # # other config - response_length: int = 53 - temperature: float = 0.7 - missing_eos_penalty: Optional[float] = None - num_ppo_epochs: int = 4 - beta: float = 0.05 - loss_type: Literal["sigmoid", "ipo"] = "sigmoid" - dataset_num_proc: Optional[int] = None + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + Args: + completion_length (`int`, *optional*, defaults to `53`): + Length of the completions to generate. + temperature (`float`, *optional*, defaults to `0.7`): + Temperature for sampling. + missing_eos_penalty (`Optional[float]`, *optional*, defaults to `None`): + Penalty when the model fails to generate an EOS token. 
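For reference, the `missing_eos_penalty` documented above is applied directly to the reward-model scores. A minimal sketch with illustrative tensors (token id 2 standing in for EOS; the in-place subtraction mirrors the behaviour this series later adopts in the "true penalty" commit):

```python
import torch

completion_ids = torch.tensor([[5, 7, 2], [5, 7, 9], [2, 0, 0]])  # illustrative token ids
scores = torch.tensor([1.2, 0.4, -0.3])
missing_eos_penalty = 1.0

# Penalize only the completions that never emitted the EOS token (id 2 here)
contain_eos_token = torch.any(completion_ids == 2, dim=-1)
scores[~contain_eos_token] -= missing_eos_penalty
print(scores)  # tensor([ 1.2000, -0.6000, -0.3000])
```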
+ beta (`float`, *optional*, defaults to `0.05`): + Beta parameter for the DPO loss. + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of DPO loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + + dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): + Number of workers to use to process the data. + """ -@dataclass -class ODPOConfig(TrainingArguments): completion_length: int = 53 temperature: float = 0.7 missing_eos_penalty: Optional[float] = None From 49a7d4760f60b8f1eec163688ef70c29b7f2c5a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:59:01 +0000 Subject: [PATCH 63/92] LogCompletionsCallback --- trl/trainer/callbacks.py | 79 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 25491140066..6250f92fca1 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -30,10 +30,12 @@ TrainerState, TrainingArguments, ) +from transformers.integrations import WandbCallback from transformers.trainer_utils import has_length from ..models.utils import unwrap_model_for_generation from .judges import BaseRankJudge +from .utils import truncate_right if is_deepspeed_available(): @@ -233,3 +235,80 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra if self.trainer.accelerator.is_main_process: win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) self.trainer.log({"eval_win_rate": win_rate}) + + +class LogCompletionsCallback(WandbCallback): + r""" + A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases. + + Usage: + ```python + prompts = ["The capital of France is", "The opposite of up is"] + trainer = DPOTrainer(..., callbacks=[LogCompletionsCallback(prompts)]) + ``` + + Args: + prompts (`List[str]`): + The prompts to generate completions for. + freq (`Optional[int]`, *optional*, defaults to `None`): + The frequency at which to log completions. If not provided, defaults to `save_steps`. 
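To make the two `loss_type` options documented above concrete, here is a hand-computed sketch, assuming the summed policy and reference log-probabilities that the trainer computes for each chosen/rejected pair (all values illustrative):

```python
import torch
import torch.nn.functional as F

beta = 0.1
pi_logratios = torch.tensor([0.8, -0.2])   # log pi(chosen) - log pi(rejected)
ref_logratios = torch.tensor([0.1, 0.0])   # same quantity under the frozen reference model
logits = pi_logratios - ref_logratios

sigmoid_loss = -F.logsigmoid(beta * logits).mean()  # loss_type="sigmoid"
ipo_loss = ((logits - 1 / (2 * beta)) ** 2).mean()  # loss_type="ipo"
```

These are exactly the two branches the trainer's `training_step` selects between.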
+ """ + + def __init__(self, prompts: List[str], freq: int = None): + super().__init__() + self.prompts = prompts + self.inputs = None # will be tokenized in on_train_begin + self.table = [] + self._last_logged_step = -1 + self.freq = freq + + def on_train_begin(self, args, state, control, **kwargs): + tokenizer = kwargs["tokenizer"] + self.inputs = tokenizer(self.prompts, return_tensors="pt", padding=True, truncation=True) + + def on_step_end(self, args, state, control, **kwargs): + # Only log from the main process + if not state.is_world_process_zero: + return + + # Only log once per step (this method may be called multiple times) + if state.global_step == self._last_logged_step: + return + + # Only log every `freq` steps (if no `freq` is provided, log every `save_steps` steps) + freq = self.freq or state.save_steps + if state.global_step % freq != 0: + return + + # Get the model and tokenizer + model = kwargs["model"] + tokenizer = kwargs["tokenizer"] + model.eval() + + # Generate completions + generation_config = GenerationConfig( + max_new_tokens=args.completion_length, min_new_tokens=args.completion_length + ) + inputs = self.inputs.to(args.device) + _, context_length = inputs["input_ids"].shape + output = model.generate(**inputs, generation_config=generation_config) + + # Get only the completions + completion_ids = output[:, context_length:] # completions.shape[1] == self.args.completion_length + + # After the first EOS token, replace all tokens with padding tokens + completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id) + + # Decode the prompts and completions + prompts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) + completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) + + # Build the data to log + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions)) + self.table.extend(data) + table = self._wandb.Table(columns=["step", "prompt", "completion"], data=self.table) + self._wandb.log({"completions": table}) + + # Save the last logged step, so we don't log the same completions multiple times + self._last_logged_step = state.global_step From f330c18f7bbb5e951b67d43a11207f58dd7dc4ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 19:59:12 +0000 Subject: [PATCH 64/92] style --- trl/trainer/online_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index c02570ef563..594962ab109 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -28,7 +28,7 @@ from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge from .online_dpo_config import OnlineDPOConfig -from .utils import DPODataCollatorWithPadding, truncate_right, get_reward, prepare_deepspeed +from .utils import DPODataCollatorWithPadding, get_reward, prepare_deepspeed, truncate_right if is_apex_available(): From a1d9ba367290821d9df10eeb51127c7717680f77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 20:01:38 +0000 Subject: [PATCH 65/92] rename trainer --- trl/trainer/online_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 594962ab109..6da5daf908a 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ 
-46,7 +46,7 @@ logger = logging.get_logger(__name__) -class ODPOTrainer(Trainer): +class OnlineDPOTrainer(Trainer): def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, From dfebb9f075fedf76ca8fa70f9d0c6c4babfd90b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 20:01:49 +0000 Subject: [PATCH 66/92] truncate right in utils --- trl/trainer/utils.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 0bbff567c8d..7d3647cde7d 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1202,3 +1202,32 @@ def add_eos_token_if_needed( rejected_tokens["input_ids"].append(eos_token_id) rejected_tokens["attention_mask"].append(1) return chosen_tokens, rejected_tokens + + +def truncate_right( + input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Truncates the input tensor from the right side after the first occurrence of the stop token. + + Args: + input_ids (`torch.Tensor`): + The tensor containing the responses to be truncated + stop_token_id (`int`): + The token ID representing the stop token where truncation occurs + pad_token_id (`int`): + The token ID representing the pad token used to fill the truncated responses + + Returns: + tuple: + - `output_ids` (`torch.Tensor`): + The truncated responses tensor with pad tokens filled after the stop token + - `mask` (`torch.Tensor`): + The mask tensor to indicate the padding tokens + """ + trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1) + new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]] + idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size) + output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) + mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) + return output_ids, mask From 4664b05730ff78202ec39520c8044f4a3f0aed21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 20:02:00 +0000 Subject: [PATCH 67/92] update example --- examples/scripts/online_dpo.py | 49 ++++------------------------------ 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index 9073d3f011c..ca81f6d38be 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -1,48 +1,9 @@ from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig -from transformers.integrations import WandbCallback +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from trl import ModelConfig +from trl import ModelConfig, OnlineDPOConfig, OnlineDPOTrainer from trl.commands.cli_utils import TrlParser -from trl.trainer.odpo import ODPOTrainer, truncate_right -from trl.trainer.online_dpo_config import ODPOConfig - - -class LogCompletionsCallback(WandbCallback): - def __init__(self, prompts, freq=None): - super().__init__() - self.inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True) - self.table = [] - self._last_logged_step = -1 - self.freq = freq - - def on_step_end(self, args, state, control, **kwargs): - if not state.is_world_process_zero: - return - if state.global_step == self._last_logged_step: - return - freq = self.freq or state.save_steps - if state.global_step % freq != 0: - return - model 
= kwargs["model"] - tokenizer = kwargs["tokenizer"] - model.eval() - generation_config = GenerationConfig( - max_new_tokens=args.completion_length, min_new_tokens=args.completion_length - ) - inputs = self.inputs.to(args.device) - _, context_length = inputs["input_ids"].shape - output = model.generate(**inputs, generation_config=generation_config) - completion_ids = output[:, context_length:] # completions.shape[1] == self.args.completion_length - completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id) - prompts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) - completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) - global_step = [str(state.global_step)] * len(prompts) - data = list(zip(global_step, prompts, completions)) - self.table.extend(data) - table = self._wandb.Table(columns=["step", "prompt", "completion"], data=self.table) - self._wandb.log({"completions": table}) - self._last_logged_step = state.global_step +from trl.trainer.callbacks import LogCompletionsCallback """ @@ -50,7 +11,7 @@ def on_step_end(self, args, state, control, **kwargs): """ if __name__ == "__main__": - parser = TrlParser((ODPOConfig, ModelConfig)) + parser = TrlParser((OnlineDPOConfig, ModelConfig)) training_args, model_config = parser.parse_args_and_config() model = AutoModelForCausalLM.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr") @@ -62,7 +23,7 @@ def on_step_end(self, args, state, control, **kwargs): dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style") - trainer = ODPOTrainer( + trainer = OnlineDPOTrainer( model=model, ref_model=ref_model, reward_model=reward_model, From 9b808caa63dae2f4de5fbc0fa0c0c8861a0b59af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 20:48:46 +0000 Subject: [PATCH 68/92] reward model path --- trl/trainer/online_dpo_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index e141be6d0b5..eb0ffc92df8 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -14,6 +14,8 @@ class OnlineDPOConfig(TrainingArguments): command line. Args: + reward_model_path (`Optional[str]`, *optional*, defaults to `None`): + Path to the reward model. completion_length (`int`, *optional*, defaults to `53`): Length of the completions to generate. temperature (`float`, *optional*, defaults to `0.7`): @@ -32,6 +34,7 @@ class OnlineDPOConfig(TrainingArguments): Number of workers to use to process the data. 
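Taken together, the fields documented above can be set as in the following sketch. The reward-model path is a hypothetical placeholder, and the numeric values mirror the defaults documented at this point in the series:

```python
from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(
    output_dir="online-dpo-model",
    reward_model_path="my-org/my-reward-model",  # hypothetical checkpoint
    completion_length=53,
    temperature=0.7,
    missing_eos_penalty=1.0,
    beta=0.05,
    loss_type="sigmoid",
)
```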
""" + reward_model_path: Optional[str] = None completion_length: int = 53 temperature: float = 0.7 missing_eos_penalty: Optional[float] = None From 66ca5bdd047408c04566f61912ecdf73457cabd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 20:49:11 +0000 Subject: [PATCH 69/92] properly log --- trl/trainer/online_dpo_trainer.py | 39 +++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 6da5daf908a..f4ebbfe5496 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -364,10 +364,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, accuracy = margin > 0 self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) - if self.state.global_step % self.state.logging_steps == 0: - self.log({key: sum(val) / len(val) for key, val in self.stats.items()}) - self.stats = {key: [] for key in self.stats} # reset stats - if ( self.args.torch_empty_cache_steps is not None and self.state.global_step % self.args.torch_empty_cache_steps == 0 @@ -391,6 +387,41 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, return loss.detach() / self.args.gradient_accumulation_steps + # Same as Trainer.evaluate but log our metrics + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + logs["learning_rate"] = self._get_learning_rate() + + # Add our metrics + for key, val in self.stats.items(): + logs[key] = sum(val) / len(val) + self.stats = {key: [] for key in self.stats} # reset stats + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + metrics = self._evaluate(trial, ignore_keys_for_eval) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + def empty_cache(self): if is_torch_xpu_available(): torch.xpu.empty_cache() From 3d280b9135580b62630979eef4e6c118004d92f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 24 Aug 2024 20:49:20 +0000 Subject: [PATCH 70/92] fix example --- examples/scripts/online_dpo.py | 69 +++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index ca81f6d38be..01ee87f4c16 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -1,7 +1,30 @@ +# flake8: noqa +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from trl import ModelConfig, OnlineDPOConfig, OnlineDPOTrainer +from trl import ( + DPOScriptArguments, + ModelConfig, + OnlineDPOConfig, + OnlineDPOTrainer, + get_kbit_device_map, + get_quantization_config, +) from trl.commands.cli_utils import TrlParser from trl.trainer.callbacks import LogCompletionsCallback @@ -11,25 +34,51 @@ """ if __name__ == "__main__": - parser = TrlParser((OnlineDPOConfig, ModelConfig)) - training_args, model_config = parser.parse_args_and_config() + parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig)) + args, training_args, model_config = parser.parse_args_and_config() + + torch_dtype = ( + model_config.torch_dtype + if model_config.torch_dtype in ["auto", None] + else getattr(torch, model_config.torch_dtype) + ) + quantization_config = get_quantization_config(model_config) + model_kwargs = dict( + revision=model_config.model_revision, + attn_implementation=model_config.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) - model = AutoModelForCausalLM.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr") - ref_model = AutoModelForCausalLM.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr") + model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) reward_model = AutoModelForSequenceClassification.from_pretrained( - "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr" + training_args.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code + ) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, + padding_side="left", + trust_remote_code=model_config.trust_remote_code, ) - tokenizer = AutoTokenizer.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr", padding_side="left") + dataset = load_dataset(args.dataset_name) - dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style") + prompts = dataset[args.dataset_test_split]["prompt"][:8] + log_completions_callback = LogCompletionsCallback(prompts) trainer = OnlineDPOTrainer( model=model, ref_model=ref_model, reward_model=reward_model, args=training_args, - train_dataset=dataset["train"], + train_dataset=dataset[args.dataset_train_split], + eval_dataset=dataset[args.dataset_test_split], tokenizer=tokenizer, - callbacks=[LogCompletionsCallback(dataset["test"]["prompt"][:8])], + callbacks=[log_completions_callback], ) trainer.train() From 2d28ad9d60fe499ea251f1553792855cbb6d4715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 25 Aug 2024 06:29:49 +0000 Subject: [PATCH 71/92] 
add generation prompt and log special tokens

---
 examples/scripts/online_dpo.py | 11 ++++++++++-
 trl/trainer/callbacks.py       |  4 ++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py
index 01ee87f4c16..8b18ccd850b 100644
--- a/examples/scripts/online_dpo.py
+++ b/examples/scripts/online_dpo.py
@@ -16,7 +16,7 @@
 import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-
+from accelerate import PartialState
 from trl import (
     DPOScriptArguments,
     ModelConfig,
@@ -68,6 +68,15 @@
     )
     dataset = load_dataset(args.dataset_name)
 
+    def process(row):
+        row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True)
+        return row
+
+    # Compute that only on the main process for faster data processing.
+    # see: https://github.com/huggingface/trl/pull/1255
+    with PartialState().local_main_process_first():
+        dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)
+
     prompts = dataset[args.dataset_test_split]["prompt"][:8]
     log_completions_callback = LogCompletionsCallback(prompts)
 
diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index 6250f92fca1..27243c6cc02 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -300,8 +300,8 @@ def on_step_end(self, args, state, control, **kwargs):
         completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id)
 
         # Decode the prompts and completions
-        prompts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True)
-        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
+        prompts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False)
+        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=False)
 
         # Build the data to log
         global_step = [str(state.global_step)] * len(prompts)

From 8f28f4f0665131530b85fe43b58c9b628cf1f004 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Mon, 26 Aug 2024 10:41:52 +0000
Subject: [PATCH 72/92] true penalty

---
 trl/trainer/online_dpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index f4ebbfe5496..2b2764b6c6d 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -294,7 +294,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         # Completions not passing that filter will receive a low (fixed) score
         contain_eos_token = torch.any(completion_ids == self.tokenizer.eos_token_id, dim=-1)
         if self.args.missing_eos_penalty is not None:
-            scores = torch.where(contain_eos_token, scores, self.args.missing_eos_penalty)
+            scores[~contain_eos_token] -= self.args.missing_eos_penalty
 
         # Replace the logprobs of the padding tokens by 1.0
         padding_mask = ~completion_mask.bool()

From f936692e07cdd8a3c92c66e8cbd10f11e0033da4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Mon, 26 Aug 2024 17:13:13 +0000
Subject: [PATCH 73/92] defaults from the paper

---
 trl/trainer/online_dpo_config.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py
index eb0ffc92df8..0c55c28006c 100644
--- a/trl/trainer/online_dpo_config.py
+++ b/trl/trainer/online_dpo_config.py
@@ -18,11 +18,11 @@ class OnlineDPOConfig(TrainingArguments):
Path to the reward model. completion_length (`int`, *optional*, defaults to `53`): Length of the completions to generate. - temperature (`float`, *optional*, defaults to `0.7`): - Temperature for sampling. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. missing_eos_penalty (`Optional[float]`, *optional*, defaults to `None`): Penalty when the model fails to generate an EOS token. - beta (`float`, *optional*, defaults to `0.05`): + beta (`float`, *optional*, defaults to `0.1`): Beta parameter for the DPO loss. loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of DPO loss to use. Possible values are: @@ -36,8 +36,8 @@ class OnlineDPOConfig(TrainingArguments): reward_model_path: Optional[str] = None completion_length: int = 53 - temperature: float = 0.7 + temperature: float = 0.9 missing_eos_penalty: Optional[float] = None - beta: float = 0.05 + beta: float = 0.1 loss_type: Literal["sigmoid", "ipo"] = "sigmoid" dataset_num_proc: Optional[int] = None From 845c1bcdc7c03a620f02813940efe9c188c68d6f Mon Sep 17 00:00:00 2001 From: lewtun Date: Tue, 27 Aug 2024 12:29:31 +0200 Subject: [PATCH 74/92] Remove MPS (#1983) --- trl/trainer/online_dpo_trainer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 37fa10c14ea..594ba84f8f0 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -19,7 +19,6 @@ is_apex_available, is_sagemaker_mp_enabled, is_torch_mlu_available, - is_torch_mps_available, is_torch_npu_available, is_torch_xpu_available, logging, @@ -28,7 +27,13 @@ from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge from .online_dpo_config import OnlineDPOConfig -from .utils import DPODataCollatorWithPadding, get_reward, prepare_deepspeed, truncate_right +from .utils import ( + DPODataCollatorWithPadding, + get_reward, + prepare_deepspeed, + trl_sanitze_kwargs_for_tagging, + truncate_right, +) if is_apex_available(): @@ -429,12 +434,8 @@ def empty_cache(self): torch.xpu.empty_cache() elif is_torch_mlu_available(): torch.mlu.empty_cache() - # elif is_torch_musa_available(): - # torch.musa.empty_cache() elif is_torch_npu_available(): torch.npu.empty_cache() - elif is_torch_mps_available(min_version="2.0"): - torch.mps.empty_cache() else: torch.cuda.empty_cache() From a567e77f0ad2e345744cd9f114e927b7cc0c7c2d Mon Sep 17 00:00:00 2001 From: lewtun Date: Tue, 27 Aug 2024 13:52:28 +0200 Subject: [PATCH 75/92] Set KV cache false when gradient checkpointing is enabled (#1984) * Remove MPS * Fix --- trl/trainer/online_dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 594ba84f8f0..08956a1aad7 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -252,6 +252,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, top_k=0.0, top_p=1.0, do_sample=True, + use_cache=False if self.args.gradient_checkpointing else True, ) num_examples, context_length = inputs["prompt_input_ids"].shape prompt_ids = inputs["prompt_input_ids"].repeat(2, 1) From 21104fdbd1d56f32d9d8f8478ee9729d1049fd46 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 14:22:12 +0000 Subject: [PATCH 76/92] Various tweask --- docs/source/online_dpo_trainer.md | 47 +- .../scripts/{online_dpo.py => 
dpo_online.py} | 35 +- examples/scripts/online_dpo_.py | 126 ---- trl/trainer/callbacks.py | 6 +- trl/trainer/online_dpo_trainer_.py | 664 ------------------ 5 files changed, 56 insertions(+), 822 deletions(-) rename examples/scripts/{online_dpo.py => dpo_online.py} (77%) delete mode 100644 examples/scripts/online_dpo_.py delete mode 100644 trl/trainer/online_dpo_trainer_.py diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 88ab9b39987..80ca9128eb5 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -1,12 +1,12 @@ # Online DPO Trainer -TRL supports training LLMs with online DPO ([Guo et al., 2024](https://huggingface.co/papers/2402.04792)) with a reward model (RM). The idea of online DPO is to generate completions based on prompts and either have a reward model or an LLM judge to rank the responses as chosen or rejected. Then the model is updated with the ranked responses using the DPO loss. +TRL supports post-training LLMs with online DPO ([Guo et al., 2024](https://huggingface.co/papers/2402.04792)). The idea of online DPO is to generate completions per batch of prompts and have either a reward model or an LLM judge rank the responses as chosen or rejected. Then the model is updated with the ranked responses using the DPO loss. While [Guo et al. (2024)](https://huggingface.co/papers/2402.04792) used an LLM judge to score model completions, the current implementation only supports reward models -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. ## Get started -The basic API looks as follows: +The basic API is as follows: ```python from datasets import Dataset @@ -18,25 +18,23 @@ from transformers import ( ) NUM_DUMMY_SAMPLES = 100 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct") -tok.add_special_tokens({"pad_token": "[PAD]"}) +tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # The model to optimise model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct") # The reference model to calculate the KL divergence against ref_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct") -# The model to score completions with. In practice, you will need a fine-tuned reward model. +# The model to score completions with. In practice, you will need a reward model. reward_model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", num_labels=1) train_dataset = Dataset.from_dict( - {"input_ids": [tok.encode("Q: Hi how are you? A:")] * NUM_DUMMY_SAMPLES}) + {"prompt": ["Q: Hi how are you? 
A:"] * NUM_DUMMY_SAMPLES}) eval_dataset = Dataset.from_dict( - {"input_ids": [tok.encode("Q: What do you like to eat A:")] * NUM_DUMMY_SAMPLES}) + {"prompt": ["Q: What do you like to eat A:"] * NUM_DUMMY_SAMPLES}) trainer = OnlineDPOTrainer( - OnlineDPOConfig( - output_dir="online-dpo-model", - ), model=model, ref_model=ref_model, reward_model=reward_model, - tokenizer=tok, + args=OnlineDPOConfig(output_dir="online-dpo-model"), + tokenizer=tokenizer, train_dataset=train_dataset, eval_dataset=eval_dataset, ) @@ -46,15 +44,28 @@ trainer.train() To run the online DPO script with a dummy reward model, run: ```bash -python examples/scripts/online_dpo.py \ +python examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-online-dpo \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --completion_length 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub + +python examples/scripts/dpo_online.py \ --dataset_name trl-lib/tldr \ --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo \ + --output_dir online-dpo-mpdel \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 64 \ --total_episodes 30000 \ --model_name_or_path EleutherAI/pythia-14m \ - --sft_model_path EleutherAI/pythia-14m \ --reward_model_path EleutherAI/pythia-14m \ --non_eos_penalty \ --stop_token eos \ @@ -67,12 +78,10 @@ python examples/scripts/online_dpo.py \ Unlike standard DPO where one provides a dataset with chosen and rejected columns, for online DPO one just needs a dataset of prompts to generate the completions from. The [`OnlineDPOTrainer`] assumes that the dataset is preprocessed for model inference, so typically you will want to wrap your prompts in the messages format and then apply the chat template as follows: ```python -def prepare_dataset(dataset, tokenizer, dataset_prompt_field): - """pre-tokenize the dataset before training; only collate during training""" - return dataset.map( - lambda x: {"input_ids": tokenizer.apply_chat_template(x[dataset_prompt_field], add_generation_prompt=True)}, - remove_columns=dataset.column_names, - ) +def prepare_dataset(row): + """Apply chat template to messages""" + row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True) + return row dataset = prepare_dataset(dataset) ``` diff --git a/examples/scripts/online_dpo.py b/examples/scripts/dpo_online.py similarity index 77% rename from examples/scripts/online_dpo.py rename to examples/scripts/dpo_online.py index 8b18ccd850b..b738491cf1c 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/dpo_online.py @@ -12,6 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+""" +Usage: + +python examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-online-dpo \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --completion_length 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub +""" import torch from datasets import load_dataset @@ -27,15 +44,12 @@ ) from trl.commands.cli_utils import TrlParser from trl.trainer.callbacks import LogCompletionsCallback - - -""" -python examples/scripts/online_dpo.py --output_dir online_dpo -""" +from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE if __name__ == "__main__": parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig)) args, training_args, model_config = parser.parse_args_and_config() + args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( model_config.torch_dtype @@ -66,19 +80,20 @@ padding_side="left", trust_remote_code=model_config.trust_remote_code, ) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE + dataset = load_dataset(args.dataset_name) - def process(row): + def prepare_dataset(row): row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True) return row - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): - dataset = dataset.map(process, num_proc=training_args.dataset_num_proc) + dataset = dataset.map(prepare_dataset, num_proc=training_args.dataset_num_proc) prompts = dataset[args.dataset_test_split]["prompt"][:8] - log_completions_callback = LogCompletionsCallback(prompts) + log_completions_callback = LogCompletionsCallback(prompts, freq=training_args.logging_steps) trainer = OnlineDPOTrainer( model=model, diff --git a/examples/scripts/online_dpo_.py b/examples/scripts/online_dpo_.py deleted file mode 100644 index 01f84e0afc1..00000000000 --- a/examples/scripts/online_dpo_.py +++ /dev/null @@ -1,126 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -from datasets import load_dataset -from transformers import ( - AutoModelForCausalLM, - AutoModelForSequenceClassification, - AutoTokenizer, -) - -from trl import HfPairwiseJudge, ModelConfig -from trl.commands.cli_utils import TrlParser -from trl.trainer import OnlineDPOConfig, OnlineDPOTrainer -from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE - - -""" -# Sanity check with minimal config and model -python examples/scripts/online_dpo.py \ - --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ - --learning_rate 3e-6 \ - --output_dir online_dpo \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 16 \ - --total_episodes 30000 \ - --model_name_or_path EleutherAI/pythia-14m \ - --judge hf_pairwise \ - --non_eos_penalty \ - --stop_token eos \ - --response_length 53 \ - --sanity_check - -accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ - examples/scripts/online_dpo.py \ - --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ - --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo \ - 
--per_device_train_batch_size 16 \ - --gradient_accumulation_steps 4 \ - --local_rollout_forward_batch_size 32 \ - --num_epochs 1 \ - --total_episodes 1000000 \ - --non_eos_penalty \ - --stop_token eos -""" - - -@dataclass -class ScriptArguments: - dataset_name: str = None - dataset_text_field: str = "prompt" - dataset_train_split: str = "train" - dataset_test_split: Optional[str] = "validation" - max_length: int = 512 - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) - args, config, model_config = parser.parse_args_and_config() - - ################ - # Model & Tokenizer - ################ - tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, - ) - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - - ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code - ) - model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code - ) - - if config.reward_model_path is not None: - reward_model = AutoModelForSequenceClassification.from_pretrained( - config.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code - ) - else: - reward_model = None - - if config.judge is not None: - judge = HfPairwiseJudge() - else: - judge = None - - ################ - # Dataset - ################ - ds = load_dataset(args.dataset_name) - if config.sanity_check: - for key in ds: - ds[key] = ds[key].select(range(1024)) - train_dataset = ds[args.dataset_train_split] - if args.dataset_test_split is not None: - eval_dataset = ds[args.dataset_test_split] - else: - eval_dataset = None - - ################ - # Training - ################ - - trainer = OnlineDPOTrainer( - model=model, - config=config, - ref_model=ref_model, - reward_model=reward_model, - judge=judge, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - ) - trainer.train() - if not config.sanity_check: - trainer.save_model(config.output_dir) - if config.push_to_hub: - trainer.push_to_hub() - trainer.generate_completions() diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 27243c6cc02..4c5b5bbff38 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -251,7 +251,7 @@ class LogCompletionsCallback(WandbCallback): prompts (`List[str]`): The prompts to generate completions for. freq (`Optional[int]`, *optional*, defaults to `None`): - The frequency at which to log completions. If not provided, defaults to `save_steps`. + The frequency at which to log completions. If not provided, defaults to `logging_steps`. 
""" def __init__(self, prompts: List[str], freq: int = None): @@ -275,8 +275,8 @@ def on_step_end(self, args, state, control, **kwargs): if state.global_step == self._last_logged_step: return - # Only log every `freq` steps (if no `freq` is provided, log every `save_steps` steps) - freq = self.freq or state.save_steps + # Only log every `freq` steps (if no `freq` is provided, log every `logging_steps` steps) + freq = self.freq or state.logging_steps if state.global_step % freq != 0: return diff --git a/trl/trainer/online_dpo_trainer_.py b/trl/trainer/online_dpo_trainer_.py deleted file mode 100644 index 5dea0c3f028..00000000000 --- a/trl/trainer/online_dpo_trainer_.py +++ /dev/null @@ -1,664 +0,0 @@ -import gc -import math -import os -import time -import warnings -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from accelerate import Accelerator, PartialState -from accelerate.utils import gather_object -from datasets import Dataset -from torch.utils.data import DataLoader -from transformers import ( - DataCollator, - GenerationConfig, - PreTrainedModel, - PreTrainedTokenizerBase, - Trainer, - TrainerCallback, - TrainerControl, - default_data_collator, -) -from transformers.integrations import get_reporting_integration_callbacks -from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK -from transformers.trainer_callback import CallbackHandler, PrinterCallback - -from ..models.utils import unwrap_model_for_generation -from .judges import BasePairwiseJudge -from .online_dpo_config import OnlineDPOConfig -from .utils import ( - DPODataCollatorWithPadding, - OnlineTrainerState, - batch_generation, - disable_dropout_in_model, - exact_div, - first_true_indices, - forward, - get_reward, - prepare_deepspeed, - print_rich_table, - truncate_response, -) - - -INVALID_LOGPROB = 1.0 - - -class OnlineDPOTrainer(Trainer): - def __init__( - self, - model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - reward_model: Optional[nn.Module] = None, - judge: Optional[BasePairwiseJudge] = None, - args: Optional[OnlineDPOConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - ) -> None: - self.args = args - self.tokenizer = tokenizer - - # disable `pad_token_id` and `eos_token_id` because we just want to - model.generation_config.eos_token_id = None - # generate tokens without truncation / padding - model.generation_config.pad_token_id = None - - self.ref_model = ref_model - self.reward_model = reward_model - self.judge = judge - if self.reward_model is not None and self.judge is not None: - warnings.warn( - "Both `reward_model` and `judge` are provided. Please choose provide only one of them. " - "Ignoring `judge` and using `reward_model`." 
- ) - elif self.reward_model is None and self.judge is None: - raise ValueError("Either `reward_model` or `judge` must be provided.") - - self.train_dataset_len = len(train_dataset) - - # Define the collator - if data_collator is None: - if tokenizer is not None: - self.data_collator = DPODataCollatorWithPadding(pad_token_id=tokenizer.pad_token_id) - else: # tokenizer is None - self.data_collator = default_data_collator - else: - self.data_collator = data_collator - - self.optimizer, self.lr_scheduler = optimizers - self.num_generation_per_prompt = 2 - - ######### - # calculate various batch sizes - ######### - if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - self.accelerator = accelerator - args.world_size = accelerator.num_processes - args.local_batch_size = ( - args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches - ) - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.batch_size = int(args.local_batch_size * args.world_size) - args.mini_batch_size = exact_div( - args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" - ) - args.local_mini_batch_size = exact_div( - args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" - ) - args.num_total_batches = math.ceil( - args.total_episodes / args.batch_size - ) # we may train for more than `total_episodes` - self.local_seed = args.seed + accelerator.process_index * 100003 # Prime - if args.num_sample_generations > 0: - self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) - self.local_dataloader_batch_size = exact_div( - args.local_batch_size, - self.num_generation_per_prompt, - "`local_batch_size` must be a multiple of `num_generation_per_prompt`", - ) # DPO logic: repeats the same prompt args.rloo_k times - - ### DPO stuff - self.beta = args.beta - self.loss_type = args.loss_type - - ######### - # setup model, optimizer, and others - ######### - if args.disable_dropout: - disable_dropout_in_model(model) - self.ref_model.eval() - if self.reward_model is not None: - self.reward_model.eval() - - if args.stop_token_id is None and args.stop_token and args.stop_token == "eos": - args.stop_token_id = tokenizer.eos_token_id - - self.model = model - self.create_optimizer_and_scheduler( - num_training_steps=args.num_total_batches - ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level - - ######### - ### trainer specifics - ######### - self.state = OnlineTrainerState( - is_local_process_zero=self.is_local_process_zero(), - is_world_process_zero=self.is_world_process_zero(), - ) - default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = CallbackHandler( - self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler - ) - self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) - self.control = TrainerControl() - self.current_flos = 0 - self.hp_search_backend = None - self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None - self.is_fsdp_enabled = 
getattr(self.accelerator.state, "fsdp_plugin", None) is not None - # Create distant repo and output directory if needed - self.hub_model_id = None - if self.args.push_to_hub: - self.init_hf_repo() - if self.args.should_save: - os.makedirs(self.args.output_dir, exist_ok=True) - self.backup_model = None - - ######### - ### setup dataloader - ######### - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().local_main_process_first(): - # tokenize the dataset - fn_kwargs = { - "is_encoder_decoder": self.model.config.is_encoder_decoder, - "tokenizer": tokenizer, - } - train_dataset = train_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) - if eval_dataset is not None: - eval_dataset = eval_dataset.map(self.tokenize_row, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc) - - self.dataloader = DataLoader( - train_dataset, - batch_size=self.local_dataloader_batch_size, - shuffle=True, - collate_fn=self.data_collator, - drop_last=True, # needed; otherwise the last batch will be of ragged shape - ) - # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` - # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c - torch.manual_seed(args.seed) - self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) - torch.manual_seed(self.local_seed) # reset the local seed again - - self.eval_dataloader = DataLoader( - eval_dataset, - batch_size=args.per_device_eval_batch_size, - collate_fn=self.data_collator, - drop_last=True, - ) # no need to shuffle eval dataset - self.eval_dataloader = accelerator.prepare(self.eval_dataloader) - - if self.is_deepspeed_enabled: - if self.reward_model is not None: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - self.ref_model = prepare_deepspeed(self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16) - self.deepspeed = self.model - else: - self.ref_model = self.ref_model.to(self.accelerator.device) - if self.reward_model is not None: - self.reward_model = self.reward_model.to(self.accelerator.device) - - @staticmethod - def tokenize_row(feature, is_encoder_decoder, tokenizer) -> Dict: - """Tokenize a single row from a DPO specific dataset.""" - if not is_encoder_decoder: - batch = tokenizer(feature["prompt"], add_special_tokens=False) - # Add BOS token to head of prompt. 
Avoid adding if it's already there - prompt_len_input_ids = len(batch["input_ids"]) - if tokenizer.bos_token_id is not None: - if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]: - batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"] - batch["attention_mask"] = [1] + batch["attention_mask"] - else: - batch = tokenizer(feature["prompt"], add_special_tokens=True) - batch = {f"prompt_{key}": value for key, value in batch.items()} - return batch - - def get_train_dataloader(self) -> DataLoader: - return self.dataloader - - def get_eval_dataloader(self) -> DataLoader: - return self.eval_dataloader - - def train(self): - args = self.args - accelerator = self.accelerator - optimizer = self.optimizer - model = self.model - self.model_wrapped = self.model - ref_model = self.ref_model - reward_model = self.reward_model - tokenizer = self.tokenizer - dataloader = self.dataloader - device = accelerator.device - - def repeat_generator(): - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - generation_config = GenerationConfig( - max_new_tokens=args.response_length, - min_new_tokens=args.response_length, - temperature=(args.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - accelerator.print("===training policy===") - start_time = time.time() - stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) - loss_stats = torch.zeros(stats_shape, device=device) - chosen_rewards_stats = torch.zeros(stats_shape, device=device) - rejected_rewards_stats = torch.zeros(stats_shape, device=device) - chosen_logprobs_stats = torch.zeros(stats_shape, device=device) - rejected_logprobs_stats = torch.zeros(stats_shape, device=device) - model.train() - - # trainer state initialization - self.state.global_step = 0 - self.state.episode = 0 - self.state.max_steps = args.num_total_batches * args.num_mini_batches - self.state.num_train_epochs = args.total_episodes / self.train_dataset_len - # Compute absolute values for logging, eval, and save if given as ratio - if args.logging_steps is not None: - if args.logging_steps < 1: - self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) - else: - self.state.logging_steps = args.logging_steps - if args.eval_steps is not None: - if args.eval_steps < 1: - self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) - else: - self.state.eval_steps = args.eval_steps - if args.save_steps is not None: - if args.save_steps < 1: - self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) - else: - self.state.save_steps = args.save_steps - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - - for step in range(args.num_total_batches): - self.state.episode += 1 * args.batch_size - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["prompt_input_ids"].to(device) - queries = queries.repeat(self.num_generation_per_prompt, 1) - context_length = queries.shape[1] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - sequence_lengths = [] - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - query_responses, logitss = batch_generation( - unwrapped_model, - queries, - args.local_rollout_forward_batch_size, - tokenizer.pad_token_id, - generation_config, - ) - - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + 
args.local_rollout_forward_batch_size] - query_response = query_responses[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] - logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob - torch.cuda.empty_cache() - - ref_output = forward(ref_model, query_response, tokenizer.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) - - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - if reward_model is not None: - _, score, _ = get_reward( - reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - else: - score = None - - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - if score is not None: - scores.append(score) - - # stack all the tensors - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - if score is not None: - scores = torch.cat(scores, 0) - - if self.judge is not None: - num_examples = postprocessed_responses.size(0) // 2 - candidates_0 = tokenizer.batch_decode( - postprocessed_responses[:num_examples], skip_special_tokens=True - ) - candidates_1 = tokenizer.batch_decode( - postprocessed_responses[num_examples:], skip_special_tokens=True - ) - completions = [[c0, c1] for c0, c1 in zip(candidates_0, candidates_1)] - preferences = self.judge.judge( - prompts=data["prompt"], completions=completions - ) # preferences is a list of prefered indexes - preferences = torch.tensor(preferences, dtype=torch.float32, device=device) - # Get the number of invalid answers by counting the number of -1 in preferences (just for logging) - invalid_rate = (preferences == -1).sum() / len(preferences) - # Replace invalid preferences with random preferences - preferences = torch.where( - preferences == -1, torch.randint(0, 2, preferences.shape, device=device), preferences - ) - # Convert preferences to scores - # The first half of the scores is the score of the first candidate. It's 1 when the first - # candidate is preferred, 0 otherwise. Since `preferences` is the index of the preferred candidate, - # the score of the first candidate is 1 - preferences. The score of the second candidate is the - # opposite of the score of the first candidate. 
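The conversion described in the comment above is compact enough to verify in isolation. A minimal sketch with judge preferences for three prompt pairs (0 means the first candidate won):

```python
import torch

preferences = torch.tensor([0.0, 1.0, 0.0])
scores = torch.cat((1 - preferences, preferences))
print(scores)  # tensor([1., 0., 1., 0., 1., 0.])
```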
- scores = torch.cat((1 - preferences, preferences)) - - del (logprob, ref_logprob, score) - torch.cuda.empty_cache() - gc.collect() - - # Response Processing 3. filter response. Ensure that the sample contains stop_token_id - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1) - if args.non_eos_penalty: - scores = torch.where(contain_eos_token, scores, args.penalty_reward_value) - # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = (-args.beta * kl).sum(1) - rlhf_reward = scores + non_score_reward - - # num_examples should be same as args.local_batch_size divided by 2 - num_examples = scores.size(0) // 2 - first_half = scores[:num_examples] - second_half = scores[num_examples:] - - num_examples_range = torch.arange(num_examples).to(scores.device) - - chosen_indices = torch.where( - first_half >= second_half, num_examples_range.clone(), num_examples_range.clone() + num_examples - ) - rejected_indices = torch.where( - first_half < second_half, num_examples_range.clone(), num_examples_range.clone() + num_examples - ) - - scores_margin = scores[chosen_indices] - scores[rejected_indices] - torch.cuda.empty_cache() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size // self.num_generation_per_prompt) - minibatch_idx = 0 - for mini_batch_start in range( - 0, - args.local_batch_size // self.num_generation_per_prompt, - args.local_mini_batch_size // self.num_generation_per_prompt, - ): - mini_batch_end = mini_batch_start + args.local_mini_batch_size // self.num_generation_per_prompt - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range( - 0, - args.local_mini_batch_size // self.num_generation_per_prompt, - args.per_device_train_batch_size, - ): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - ## chosen - chosen_mb_inds = chosen_indices[micro_batch_inds] - chosen_responses = responses[chosen_mb_inds] - - ## rejected - rejected_mb_inds = rejected_indices[micro_batch_inds] - rejected_responses = responses[rejected_mb_inds] - - concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) - concat_query_responses = query_responses[concat_mb_inds] - concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id) - num_examples = chosen_mb_inds.shape[0] - chosen_logits = concat_output.logits[:num_examples] - rejected_logits = concat_output.logits[num_examples:] - - # chosen - chosen_logits = chosen_logits[:, context_length - 1 : -1] - chosen_logits /= args.temperature + 1e-7 - chosen_all_logprobs = F.log_softmax(chosen_logits, 
dim=-1) - chosen_logprobs = torch.gather( - chosen_all_logprobs, 2, chosen_responses.unsqueeze(-1) - ).squeeze(-1) - chosen_logprobs = torch.masked_fill( - chosen_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB - ) - chosen_ref_logprobs = ref_logprobs[chosen_mb_inds] - chosen_logprobs_sum = (chosen_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) - chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) - - # rejected - rejected_logits = rejected_logits[:, context_length - 1 : -1] - rejected_logits /= args.temperature + 1e-7 - rejected_all_logprobs = F.log_softmax(rejected_logits, dim=-1) - rejected_logprobs = torch.gather( - rejected_all_logprobs, 2, rejected_responses.unsqueeze(-1) - ).squeeze(-1) - rejected_logprobs = torch.masked_fill( - rejected_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB - ) - rejected_ref_logprobs = ref_logprobs[rejected_mb_inds] - rejected_logprobs_sum = (rejected_logprobs * ~padding_mask[rejected_mb_inds]).sum(1) - rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( - 1 - ) - - pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum - ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum - - logits = pi_logratios - ref_logratios - - if self.loss_type == "sigmoid": - losses = -F.logsigmoid(self.beta * logits) - elif self.loss_type == "ipo": - losses = (logits - 1 / (2 * self.beta)) ** 2 - else: - raise NotImplementedError(f"invalid loss type {self.loss_type}") - - loss = losses.mean() - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) - rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) - loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss - chosen_rewards_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = chosen_rewards.mean() - rejected_rewards_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = rejected_rewards.mean() - chosen_logprobs_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = chosen_logprobs_sum.mean() - rejected_logprobs_stats[ - ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx - ] = rejected_logprobs_sum.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - self.state.global_step += 1 - # del everything and empty cache - # fmt: off - del ( - loss, logits, - concat_output, concat_query_responses, - chosen_logits, rejected_logits, - chosen_logprobs, rejected_logprobs, - chosen_responses, rejected_responses, - ) - # fmt: on - torch.cuda.empty_cache() - - # Log metrics - with torch.no_grad(): - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.mean() - eps = int(self.state.episode / (time.time() - start_time)) - g_chosen_reward = self.accelerator.gather(chosen_rewards_stats) - g_rejected_reward = self.accelerator.gather(rejected_rewards_stats) - metrics = {} - metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() - metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() - 
metrics["objective/scores_margin"] = self.accelerator.gather(scores_margin.mean()).mean().item() - metrics["rewards/chosen"] = g_chosen_reward.mean().item() - metrics["rewards/rejected"] = g_rejected_reward.mean().item() - metrics["rewards/accuracies"] = (g_chosen_reward > g_rejected_reward).float().mean().item() - metrics["rewards/margins"] = (g_chosen_reward - g_rejected_reward).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() - metrics["logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats).mean().item() - metrics["logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats).mean().item() - metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item() - metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = self.state.episode - if self.judge is not None: - metrics["judge/invalid_rate"] = invalid_rate.item() - self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log - self.state.global_step += 1 - self.log(metrics) - del (kl, mean_kl, mean_entropy, scores, scores_margin) - - self.lr_scheduler.step() - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=metrics) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - torch.cuda.empty_cache() - gc.collect() - - if args.num_sample_generations > 0 and step % self.sample_generations_freq == 0: - self.generate_completions(sampling=True) - - # HF trainer specifics - self.control = self.callback_handler.on_train_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=None) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - - def generate_completions(self, sampling: bool = False): - args = self.args - tokenizer = self.tokenizer - generation_config = GenerationConfig( - max_new_tokens=self.args.response_length, - temperature=(0.01 + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - table = defaultdict(list) - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - for batch in self.eval_dataloader: - query = batch["prompt_input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - query_response, _ = batch_generation( - unwrapped_model, - query, - query.shape[0], - tokenizer.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) - query_text = tokenizer.batch_decode(query, skip_special_tokens=True) - postprocessed_response_text = tokenizer.batch_decode( - postprocessed_response, skip_special_tokens=True - ) - table["Query"].extend(gather_object(query_text)) - table["Model response"].extend(gather_object(postprocessed_response_text)) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - if self.reward_model is not None: - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - - if sampling: - break - df = pd.DataFrame(table) - if self.accelerator.process_index == 0: - print_rich_table(df.iloc[0 : 0 + 5]) - 
if "wandb" in args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) From 8a033c7d9bcf176b30157982e8a9afee89ced6eb Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 14:42:16 +0000 Subject: [PATCH 77/92] Remove padding from table --- examples/scripts/dpo_online.py | 2 +- trl/trainer/callbacks.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index b738491cf1c..9f79cd789ac 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -82,7 +82,7 @@ ) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - + dataset = load_dataset(args.dataset_name) def prepare_dataset(row): diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 4c5b5bbff38..d3f9a10ec7b 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -300,8 +300,14 @@ def on_step_end(self, args, state, control, **kwargs): completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id) # Decode the prompts and completions - prompts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False) - completions = tokenizer.batch_decode(completion_ids, skip_special_token=False) + prompts = [ + p.replace(tokenizer.pad_token, "") + for p in tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False) + ] + completions = [ + c.replace(tokenizer.pad_token, "") + for c in tokenizer.batch_decode(completion_ids, skip_special_token=False) + ] # Build the data to log global_step = [str(state.global_step)] * len(prompts) From 4fd0666fb02f1ecbf3c978a665fdd17d4e4bba2e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 15:43:28 +0000 Subject: [PATCH 78/92] Clean up --- docs/source/online_dpo_trainer.md | 252 +++++++++--------------------- trl/trainer/callbacks.py | 6 +- trl/trainer/online_dpo_config.py | 6 +- trl/trainer/online_dpo_trainer.py | 8 +- 4 files changed, 86 insertions(+), 186 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 80ca9128eb5..3847fb9c2ee 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -41,7 +41,7 @@ trainer = OnlineDPOTrainer( trainer.train() ``` -To run the online DPO script with a dummy reward model, run: +To test the online DPO script with 1B parameter models, run: ```bash python examples/scripts/dpo_online.py \ @@ -50,27 +50,13 @@ python examples/scripts/dpo_online.py \ --dataset_name trl-lib/tldr \ --learning_rate 5.0e-7 \ --output_dir pythia-1b-tldr-online-dpo \ - --per_device_train_batch_size 8 \ - --gradient_accumulation_steps 2 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ --num_train_epochs 3 \ --completion_length 53 \ --warmup_ratio 0.1 \ --missing_eos_penalty 1.0 \ --push_to_hub - -python examples/scripts/dpo_online.py \ - --dataset_name trl-lib/tldr \ - --learning_rate 3e-6 \ - --output_dir online-dpo-mpdel \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 64 \ - --total_episodes 30000 \ - --model_name_or_path EleutherAI/pythia-14m \ - --reward_model_path EleutherAI/pythia-14m \ - --non_eos_penalty \ - --stop_token eos \ - --response_length 53 \ - --sanity_check ``` ## Expected dataset format @@ -90,7 +76,6 @@ dataset = prepare_dataset(dataset) The logged metrics are as follows. 
Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35) -* `eps`: Tracks the number of episodes per second. * `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model. * `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model. * `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. @@ -103,9 +88,8 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an * `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. * `logps/chosen`: The mean log probabilities of the chosen completions. * `logps/rejected`: The mean log probabilities of the rejected completions. -* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses. +* `val/contain_eos+token`: The fraction of completions which contain and EOS token. * `lr`: lr: The current learning rate used by the optimizer. -* `episode`: episode: The current global step or episode count in the training process. ## Cookbook @@ -117,174 +101,89 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an * Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. -* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions. +* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty. This can help the model learn to generate more coherent completions. ## What is my model doing exactly? -To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations. +To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate during training, but you can customize the number of generations. 
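To make the ranking step concrete, here is a minimal sketch (dummy scores, not the trainer's exact code) of how the two completions sampled per prompt are split into chosen and rejected pairs. It mirrors the batch layout used internally, where the first half of the batch holds each prompt's first completion and the second half the second one:

```python
import torch

# Dummy reward-model scores for 2 prompts x 2 completions.
scores = torch.tensor([0.3, 1.2, 0.9, 0.4])
num_examples = scores.size(0) // 2
first_half, second_half = scores[:num_examples], scores[num_examples:]

idxs = torch.arange(num_examples)
# The higher-scoring completion of each pair becomes the "chosen" one.
chosen_indices = torch.where(first_half >= second_half, idxs, idxs + num_examples)
rejected_indices = torch.where(first_half < second_half, idxs, idxs + num_examples)
print(chosen_indices.tolist(), rejected_indices.tolist())  # [2, 1] [0, 3]
```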
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif) - - -In the logs the sampled generations look like - -``` -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ -┃ query ┃ model response ┃ score ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ -│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │ -│ │ I don't know how to get rid of │ │ -│ TITLE: How do you get someone │ those feelings. I'm │ │ -│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │ -│ │ │ │ -│ POST: Hi, │ │ │ -│ I'm 22, and I have been with my │ │ │ -│ girlfriend for 5 years now. We │ │ │ -│ recently moved together. We've │ │ │ -│ always loved each other │ │ │ -│ intensely. │ │ │ -│ │ │ │ -│ Problem, I recently started to │ │ │ -│ have feelings for an other │ │ │ -│ person (a friend). This person │ │ │ -│ has had a boyfriend for now 3 │ │ │ -│ years, and has absolutely no │ │ │ -│ ideas. Those feelings were so │ │ │ -│ strong, it was hard to hide │ │ │ -│ them. After 2 months of me │ │ │ -│ being distant and really sad, │ │ │ -│ my girlfriend forced me to say │ │ │ -│ what was bothering me. I'm not │ │ │ -│ a good liar, and now she knows. │ │ │ -│ │ │ │ -│ We decided to give us a week │ │ │ -│ alone, I went to my parents. │ │ │ -│ │ │ │ -│ Now, I'm completely lost. I │ │ │ -│ keep on thinking about this │ │ │ -│ person, and I hate that. I │ │ │ -│ would like for those feelings │ │ │ -│ to go away, to leave me alone. │ │ │ -│ But I can't. │ │ │ -│ │ │ │ -│ What do I do? It's been 3 │ │ │ -│ months now, and I'm just │ │ │ -│ desperate. │ │ │ -│ │ │ │ -│ TL;DR: │ │ │ -├─────────────────────────────────┼─────────────────────────────────┼──────────┤ -│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │ -│ │ TV. I blasted Gangnam Style on │ │ -│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │ -│ with a loud TV. │ up as high as it could │ │ -│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │ -│ POST: She was in her living │ │ │ -│ room, watching TV. This was at │ │ │ -│ about 8:30 in the morning, and │ │ │ -│ she was exercising. She turned │ │ │ -│ the TV up extra loud to hear it │ │ │ -│ over her excercycle, and woke │ │ │ -│ me up. I went in there asking │ │ │ -│ for her to turn it down. She │ │ │ -│ said she didn't have to; I │ │ │ -│ explained that I always used │ │ │ -│ headphones so she didn't have │ │ │ -│ to deal with my noise and that │ │ │ -│ she should give me a little │ │ │ -│ more respect, given that I paid │ │ │ -│ rent at the time. │ │ │ -│ │ │ │ -│ She disagreed. I went back to │ │ │ -│ my room, rather pissed off at │ │ │ -│ the lack of equality. I had no │ │ │ -│ lock on my door; but I had a │ │ │ -│ dresser right next to it, so I │ │ │ -│ pulled one of the drawers out │ │ │ -│ enough so that it caused the │ │ │ -│ door to not be openable. Then, │ │ │ -│ I turned my speakers up really │ │ │ -│ loud and blasted Gangnam Style │ │ │ -│ on repeat, with the bass │ │ │ -│ cranked up as high as it could │ │ │ -│ go. │ │ │ -│ │ │ │ -│ If you hate Gangnam Style for │ │ │ -│ being overplayed, you will see │ │ │ -│ why I chose that particular │ │ │ -│ song. I personally don't mind │ │ │ -│ it. But here's the thing about │ │ │ -│ my bass; it vibrates the walls, │ │ │ -│ making one hell of a lot of │ │ │ -│ noise. Needless to say, my mom │ │ │ -│ was not pleased and shut off │ │ │ -│ the internet. But it was oh so │ │ │ -│ worth it. 
│ │ │ -│ │ │ │ -│ TL;DR: │ │ │ -└─────────────────────────────────┴─────────────────────────────────┴──────────┘ -``` ## Implementation details -Many online implementation details are borrowed from the PPOv2Trainer, which is itself based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). Here are some additional implementation details: +Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). Here are some additional implementation details: -1. When we turn on the EOS trick (i.e., replacing the score of completions that do not end with an EOS token with a scalar penalty score like `-1`) via `--non_eos_penalty --stop_token eos`, it's possible that the chosen and rejected completions have the same score. In this case, we will naively select the completion with the lower index and the chosen completion. +1. When we turn on the EOS trick (i.e., replacing the score of completions that do not end with an EOS token with a scalar penalty score like `-1`) via `--missing_eos_penalty`, it's possible that the chosen and rejected completions have the same score. In this case, we will naively select the completion with the lower index and the chosen completion. ## Benchmark experiments -To validate the online DPO implementation works, we ran experiments on the 1B and 6.9B models. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). +To validate the online DPO implementation works, we ran experiments on the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). 
``` # 1B Online DPO experiment -accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ - examples/scripts/online_dpo.py \ +accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ --dataset_name trl-lib/tldr \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_tldr \ - --per_device_train_batch_size 16 \ - --gradient_accumulation_steps 4 \ - --local_rollout_forward_batch_size 32 \ - --num_epochs 1 \ - --num_mini_batches 1 \ - --total_episodes 1000000 \ - --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ - --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ - --save_strategy no \ - --non_eos_penalty \ - --stop_token eos \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-deduped-tldr-online-dpo \ --beta 0.1 \ - --response_length 53 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --logging_steps 20 \ + --save_steps 0.1 \ --push_to_hub +# 2.8B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-2.8b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --bf16 \ + --logging_steps 20 \ + --save_steps 0.1 \ + --push_to_hub \ + # 6.9B Online DPO experiment accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ - examples/scripts/online_dpo.py \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \ --dataset_name trl-lib/tldr \ - --learning_rate 3e-6 \ - --output_dir models/minimal/online_dpo_tldr_6.9b \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 16 \ - --local_rollout_forward_batch_size 8 \ - --num_epochs 1 \ - --num_mini_batches 1 \ - --total_episodes 1000000 \ - --model_name_or_path EleutherAI/pythia-6.9b-deduped \ - --sft_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \ - --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \ - --save_strategy no \ - --non_eos_penalty \ - --stop_token eos \ + --learning_rate 5.0e-7 \ + --output_dir pythia-6.9b-deduped-tldr-online-dpo \ --beta 0.1 \ - --response_length 53 \ - --push_to_hub + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --bf16 \ + --logging_steps 20 \ + --save_steps 0.1 \ + --push_to_hub \ ``` Checkpoints and experiment tracking are available at: -- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr) +- [🤗 Model checkpoint](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07) - [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35) @@ -292,13 +191,13 @@ To evaluate, we use 
[vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as the judge model to evaluate the generated completions against the reference completions.
For more information on how to use judges, see [Judges](judges).

```bash
-$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
+$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
-python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
+python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 41.50%
-python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/online_dpo_tldr --judge_model gpt-4o-mini --num_examples 1000
+python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 62.60%
-python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/online_dpo_tldr_6.9b --judge_model gpt-4o-mini --num_examples 1000
+python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 74.20%
```

We can then plot the RLHF scaling chart.

```python
import matplotlib.pyplot as plt

-data = {
-    "SFT": [[1e9, 6.9e9], [0.33, 0.415]],
-    "Online DPO": [[1e9, 6.9e9], [0.626, 0.742]],
+results = {
+    "SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
+    "online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
+    "offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
}
-for model, (x, y) in data.items():
-    plt.scatter(x, y, label=model)
+
+plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o")
+plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o")
+plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o")
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
-plt.title("RLHF scaling by model size")
-plt.xlabel("Model size")
-plt.ylabel("Win rate against reference summaries\n(according to GPT-4o mini)")
plt.xscale("log")
-plt.xlim(5e8, 1.2e10)
-plt.xticks([1e9, 1e10], ["1B", "10B"])
+plt.xlabel("Model size")
+plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
+plt.title("DPO scaling by model size")
plt.legend()
+plt.xlim(5e8, 1.2e10)
+plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
plt.grid(True, which="both", ls="--", c="0.7")
plt.tight_layout()
-plt.savefig("plot.png")
+plt.show()
```

diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index d3f9a10ec7b..14515eeabfa 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -286,15 +286,13 @@ def on_step_end(self, args, state, control, **kwargs):
        model.eval()

        # Generate completions
-        generation_config = GenerationConfig(
-            max_new_tokens=args.completion_length, min_new_tokens=args.completion_length
-        )
+        generation_config = GenerationConfig(max_new_tokens=args.max_new_tokens, min_new_tokens=args.max_new_tokens)
        inputs = self.inputs.to(args.device)
        _, context_length = inputs["input_ids"].shape
        output = model.generate(**inputs, generation_config=generation_config)

        # Get only the completions
-        completion_ids = output[:, context_length:]  # completions.shape[1] == self.args.completion_length
+        completion_ids = output[:, context_length:]

        # After the first EOS token, replace all tokens with padding tokens
        completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id)
diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py
index 0c55c28006c..f1c1ccab4a8 100644
--- a/trl/trainer/online_dpo_config.py
+++ b/trl/trainer/online_dpo_config.py
@@ -16,8 +16,8 @@ class OnlineDPOConfig(TrainingArguments):
    Args:
        reward_model_path (`Optional[str]`, *optional*, defaults to `None`):
            Path to the reward model.
-        completion_length (`int`, *optional*, defaults to `53`):
-            Length of the completions to generate.
+        max_new_tokens (`int`, *optional*, defaults to `53`):
+            The maximum number of tokens to generate per completion.
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        missing_eos_penalty (`Optional[float]`, *optional*, defaults to `None`):
@@ -35,7 +35,7 @@ class OnlineDPOConfig(TrainingArguments):
    """

    reward_model_path: Optional[str] = None
-    completion_length: int = 53
+    max_new_tokens: int = 53
    temperature: float = 0.9
    missing_eos_penalty: Optional[float] = None
    beta: float = 0.1
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 08956a1aad7..df188853ebc 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -243,11 +243,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        model.train()

-        # Sample 2 completations per prompt of size `completion_length` from the model
+        # Sample 2 completions per prompt of size `max_new_tokens` from the model
        inputs = self._prepare_inputs(inputs)
        generation_config = GenerationConfig(
-            max_new_tokens=self.args.completion_length,
-            min_new_tokens=self.args.completion_length,
+            max_new_tokens=self.args.max_new_tokens,
+            min_new_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=0.0,
            top_p=1.0,
@@ -265,7 +265,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
        )
        del inputs

-        completion_ids = output[:, context_length:]  # completions.shape[1] == self.args.completion_length
+        completion_ids = output[:, context_length:]
        completion_ids, completion_mask = truncate_right(
            completion_ids, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id
        )

From b36dc0ea1c33be1dd18c1c1e082796257223658e Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 27 Aug 2024 16:38:36 +0000
Subject: [PATCH 79/92] Fix test

---
 tests/test_online_dpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py
index e69ddfe2f99..d5242b2b464 100644
--- a/tests/test_online_dpo_trainer.py
+++ b/tests/test_online_dpo_trainer.py
@@ -84,7 +84,7 @@ def test_online_dpo_trainer_training(self):
            model=self.model,
            ref_model=self.model,
            reward_model=self.reward_model,
-            config=training_args,
+            args=training_args,
            tokenizer=self.tokenizer,
            train_dataset=self.dummy_dataset,
            eval_dataset=self.dummy_dataset,

From 8db4b718ff673408901485977b02b1849bc2619e Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 27 Aug 2024 19:39:06 +0000
Subject: [PATCH 80/92] Revert log freq

---
 examples/scripts/dpo_online.py | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 9f79cd789ac..4e9ebc6a0b6 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -93,7 +93,6 @@ def prepare_dataset(row): dataset = dataset.map(prepare_dataset, num_proc=training_args.dataset_num_proc) prompts = dataset[args.dataset_test_split]["prompt"][:8] - log_completions_callback = LogCompletionsCallback(prompts, freq=training_args.logging_steps) trainer = OnlineDPOTrainer( model=model, @@ -103,6 +102,7 @@ def prepare_dataset(row): train_dataset=dataset[args.dataset_train_split], eval_dataset=dataset[args.dataset_test_split], tokenizer=tokenizer, - callbacks=[log_completions_callback], ) + log_completions_callback = LogCompletionsCallback(prompts, trainer) + trainer.add_callback(log_completions_callback) trainer.train() From 4439a9aced3da643ee2ee68bbe187ed003acea3a Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 19:49:33 +0000 Subject: [PATCH 81/92] Fix docs --- docs/source/online_dpo_trainer.md | 11 ++-- trl/trainer/callbacks.py | 83 ------------------------------- 2 files changed, 5 insertions(+), 89 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 3847fb9c2ee..f15828abd59 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -82,14 +82,13 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an * `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. * `objective/scores`: The mean scores returned by the reward model / environment. * `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions. -* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model. * `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions. * `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions. +* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model. * `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. * `logps/chosen`: The mean log probabilities of the chosen completions. * `logps/rejected`: The mean log probabilities of the rejected completions. -* `val/contain_eos+token`: The fraction of completions which contain and EOS token. -* `lr`: lr: The current learning rate used by the optimizer. +* `val/contain_eos_token`: The fraction of completions which contain an EOS token. 
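As a rough illustration of how the `rewards/*` metrics above are derived, assuming per-sequence log probabilities summed over non-padding tokens (dummy values below):

```python
import torch

beta = 0.1  # KL penalty coefficient
# Summed log probabilities of each completion under the policy and the frozen reference model.
chosen_logprobs_sum = torch.tensor([-10.2, -8.7])
chosen_ref_logprobs_sum = torch.tensor([-11.0, -9.5])
rejected_logprobs_sum = torch.tensor([-14.1, -12.3])
rejected_ref_logprobs_sum = torch.tensor([-13.5, -12.8])

chosen_rewards = beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)        # rewards/chosen
rejected_rewards = beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)  # rewards/rejected
margins = chosen_rewards - rejected_rewards                                    # rewards/margins
accuracies = (chosen_rewards > rejected_rewards).float()                       # rewards/accuracies
print(margins.mean().item(), accuracies.mean().item())
```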
## Cookbook @@ -158,7 +157,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --bf16 \ --logging_steps 20 \ --save_steps 0.1 \ - --push_to_hub \ + --push_to_hub # 6.9B Online DPO experiment accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ @@ -176,9 +175,9 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --warmup_ratio 0.1 \ --missing_eos_penalty 1.0 \ --bf16 \ - --logging_steps 20 \ + --logging_steps 1 \ --save_steps 0.1 \ - --push_to_hub \ + --push_to_hub ``` Checkpoints and experiment tracking are available at: diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 14515eeabfa..25491140066 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -30,12 +30,10 @@ TrainerState, TrainingArguments, ) -from transformers.integrations import WandbCallback from transformers.trainer_utils import has_length from ..models.utils import unwrap_model_for_generation from .judges import BaseRankJudge -from .utils import truncate_right if is_deepspeed_available(): @@ -235,84 +233,3 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra if self.trainer.accelerator.is_main_process: win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) self.trainer.log({"eval_win_rate": win_rate}) - - -class LogCompletionsCallback(WandbCallback): - r""" - A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases. - - Usage: - ```python - prompts = ["The capital of France is", "The opposite of up is"] - trainer = DPOTrainer(..., callbacks=[LogCompletionsCallback(prompts)]) - ``` - - Args: - prompts (`List[str]`): - The prompts to generate completions for. - freq (`Optional[int]`, *optional*, defaults to `None`): - The frequency at which to log completions. If not provided, defaults to `logging_steps`. 
- """ - - def __init__(self, prompts: List[str], freq: int = None): - super().__init__() - self.prompts = prompts - self.inputs = None # will be tokenized in on_train_begin - self.table = [] - self._last_logged_step = -1 - self.freq = freq - - def on_train_begin(self, args, state, control, **kwargs): - tokenizer = kwargs["tokenizer"] - self.inputs = tokenizer(self.prompts, return_tensors="pt", padding=True, truncation=True) - - def on_step_end(self, args, state, control, **kwargs): - # Only log from the main process - if not state.is_world_process_zero: - return - - # Only log once per step (this method may be called multiple times) - if state.global_step == self._last_logged_step: - return - - # Only log every `freq` steps (if no `freq` is provided, log every `logging_steps` steps) - freq = self.freq or state.logging_steps - if state.global_step % freq != 0: - return - - # Get the model and tokenizer - model = kwargs["model"] - tokenizer = kwargs["tokenizer"] - model.eval() - - # Generate completions - generation_config = GenerationConfig(max_new_tokens=args.max_new_tokens, min_new_tokens=args.max_new_tokens) - inputs = self.inputs.to(args.device) - _, context_length = inputs["input_ids"].shape - output = model.generate(**inputs, generation_config=generation_config) - - # Get only the completions - completion_ids = output[:, context_length:] - - # After the first EOS token, replace all tokens with padding tokens - completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id) - - # Decode the prompts and completions - prompts = [ - p.replace(tokenizer.pad_token, "") - for p in tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False) - ] - completions = [ - c.replace(tokenizer.pad_token, "") - for c in tokenizer.batch_decode(completion_ids, skip_special_token=False) - ] - - # Build the data to log - global_step = [str(state.global_step)] * len(prompts) - data = list(zip(global_step, prompts, completions)) - self.table.extend(data) - table = self._wandb.Table(columns=["step", "prompt", "completion"], data=self.table) - self._wandb.log({"completions": table}) - - # Save the last logged step, so we don't log the same completions multiple times - self._last_logged_step = state.global_step From 35eff1d25def45de7694d0b97ff2905c2f7aec57 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 19:51:25 +0000 Subject: [PATCH 82/92] Fix tests aain! 
--- docs/source/online_dpo_trainer.md | 3 ++- tests/test_online_dpo_trainer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index f15828abd59..1df16c302c5 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -160,7 +160,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --push_to_hub # 6.9B Online DPO experiment -accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ examples/scripts/dpo_online.py \ --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \ --reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \ @@ -175,6 +175,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --warmup_ratio 0.1 \ --missing_eos_penalty 1.0 \ --bf16 \ + --gradient_checkpointing \ --logging_steps 1 \ --save_steps 0.1 \ --push_to_hub diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index d5242b2b464..f56c4faa192 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -93,4 +93,4 @@ def test_online_dpo_trainer_training(self): trainer.train() # Check if training loss is available - self.assertIn("loss/policy_avg", trainer.state.log_history[-1]) + self.assertIn("train_loss", trainer.state.log_history[-1]) From a64721df82a700191f386df3974229fd958a54bc Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 19:55:15 +0000 Subject: [PATCH 83/92] Fix typo --- docs/source/online_dpo_trainer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 1df16c302c5..68aa9c8b9ba 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -176,7 +176,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --missing_eos_penalty 1.0 \ --bf16 \ --gradient_checkpointing \ - --logging_steps 1 \ + --logging_steps 20 \ --save_steps 0.1 \ --push_to_hub ``` From ae4a1ed8f5e75de0f52c53249edb687bbccda6d9 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 20:02:13 +0000 Subject: [PATCH 84/92] Revert --- docs/source/online_dpo_trainer.md | 7 +-- trl/trainer/callbacks.py | 83 +++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 68aa9c8b9ba..fcd0b867df2 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -82,13 +82,14 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an * `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. * `objective/scores`: The mean scores returned by the reward model / environment. * `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions. +* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model. * `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions. * `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions. -* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model. 
* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. * `logps/chosen`: The mean log probabilities of the chosen completions. * `logps/rejected`: The mean log probabilities of the rejected completions. -* `val/contain_eos_token`: The fraction of completions which contain an EOS token. +* `val/contain_eos+token`: The fraction of completions which contain and EOS token. +* `lr`: lr: The current learning rate used by the optimizer. ## Cookbook @@ -157,7 +158,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --bf16 \ --logging_steps 20 \ --save_steps 0.1 \ - --push_to_hub + --push_to_hub \ # 6.9B Online DPO experiment accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 25491140066..14515eeabfa 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -30,10 +30,12 @@ TrainerState, TrainingArguments, ) +from transformers.integrations import WandbCallback from transformers.trainer_utils import has_length from ..models.utils import unwrap_model_for_generation from .judges import BaseRankJudge +from .utils import truncate_right if is_deepspeed_available(): @@ -233,3 +235,84 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra if self.trainer.accelerator.is_main_process: win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) self.trainer.log({"eval_win_rate": win_rate}) + + +class LogCompletionsCallback(WandbCallback): + r""" + A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases. + + Usage: + ```python + prompts = ["The capital of France is", "The opposite of up is"] + trainer = DPOTrainer(..., callbacks=[LogCompletionsCallback(prompts)]) + ``` + + Args: + prompts (`List[str]`): + The prompts to generate completions for. + freq (`Optional[int]`, *optional*, defaults to `None`): + The frequency at which to log completions. If not provided, defaults to `logging_steps`. 
+ """ + + def __init__(self, prompts: List[str], freq: int = None): + super().__init__() + self.prompts = prompts + self.inputs = None # will be tokenized in on_train_begin + self.table = [] + self._last_logged_step = -1 + self.freq = freq + + def on_train_begin(self, args, state, control, **kwargs): + tokenizer = kwargs["tokenizer"] + self.inputs = tokenizer(self.prompts, return_tensors="pt", padding=True, truncation=True) + + def on_step_end(self, args, state, control, **kwargs): + # Only log from the main process + if not state.is_world_process_zero: + return + + # Only log once per step (this method may be called multiple times) + if state.global_step == self._last_logged_step: + return + + # Only log every `freq` steps (if no `freq` is provided, log every `logging_steps` steps) + freq = self.freq or state.logging_steps + if state.global_step % freq != 0: + return + + # Get the model and tokenizer + model = kwargs["model"] + tokenizer = kwargs["tokenizer"] + model.eval() + + # Generate completions + generation_config = GenerationConfig(max_new_tokens=args.max_new_tokens, min_new_tokens=args.max_new_tokens) + inputs = self.inputs.to(args.device) + _, context_length = inputs["input_ids"].shape + output = model.generate(**inputs, generation_config=generation_config) + + # Get only the completions + completion_ids = output[:, context_length:] + + # After the first EOS token, replace all tokens with padding tokens + completion_ids, _ = truncate_right(completion_ids, tokenizer.eos_token_id, tokenizer.pad_token_id) + + # Decode the prompts and completions + prompts = [ + p.replace(tokenizer.pad_token, "") + for p in tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False) + ] + completions = [ + c.replace(tokenizer.pad_token, "") + for c in tokenizer.batch_decode(completion_ids, skip_special_token=False) + ] + + # Build the data to log + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions)) + self.table.extend(data) + table = self._wandb.Table(columns=["step", "prompt", "completion"], data=self.table) + self._wandb.log({"completions": table}) + + # Save the last logged step, so we don't log the same completions multiple times + self._last_logged_step = state.global_step From 0b6ac0e845a3554d34ffe245f8ef2c9141219f1e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 20:24:31 +0000 Subject: [PATCH 85/92] Fix regression --- examples/scripts/dpo_online.py | 2 +- trl/trainer/callbacks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 4e9ebc6a0b6..09d92df9905 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -103,6 +103,6 @@ def prepare_dataset(row): eval_dataset=dataset[args.dataset_test_split], tokenizer=tokenizer, ) - log_completions_callback = LogCompletionsCallback(prompts, trainer) + log_completions_callback = LogCompletionsCallback(prompts) trainer.add_callback(log_completions_callback) trainer.train() diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 14515eeabfa..c7fb5f51964 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -304,7 +304,7 @@ def on_step_end(self, args, state, control, **kwargs): ] completions = [ c.replace(tokenizer.pad_token, "") - for c in tokenizer.batch_decode(completion_ids, skip_special_token=False) + for c in tokenizer.batch_decode(completion_ids, skip_special_tokens=False) ] # Build the data to log From 
dab37dc0759b8ca4da142a9583d6e72170adf85d Mon Sep 17 00:00:00 2001
From: lewtun
Date: Tue, 27 Aug 2024 22:26:04 +0200
Subject: [PATCH 86/92] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---
 docs/source/online_dpo_trainer.md | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md
index fcd0b867df2..661ba5f0349 100644
--- a/docs/source/online_dpo_trainer.md
+++ b/docs/source/online_dpo_trainer.md
@@ -88,7 +88,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the rejected completions.
-* `val/contain_eos+token`: The fraction of completions which contain and EOS token.
+* `val/contain_eos_token`: The fraction of completions which contain an EOS token.
* `lr`: The current learning rate used by the optimizer.


@@ -101,7 +101,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
-* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty. This can help the model learn to generate more coherent completions.
+* Usage TIP: We recommend using the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.

## What is my model doing exactly?

@@ -113,7 +113,6 @@ To help you understand what your model is doing, we periodically log some sample
Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). Here are some additional implementation details:

-1. When we turn on the EOS trick (i.e., replacing the score of completions that do not end with an EOS token with a scalar penalty score like `-1`) via `--missing_eos_penalty`, it's possible that the chosen and rejected completions have the same score. In this case, we will naively select the completion with the lower index as the chosen completion.
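For reference, a minimal sketch of the two flavours of the EOS trick discussed here, using dummy tensors rather than the trainer's exact code: the earlier variant replaces the score outright, while `--missing_eos_penalty` subtracts a fixed penalty from it.

```python
import torch

eos_token_id = 0
completion_ids = torch.tensor([[4, 5, 0], [4, 5, 6]])  # second completion never emits EOS
scores = torch.tensor([1.5, 1.5])
contain_eos = (completion_ids == eos_token_id).any(dim=-1)

# Earlier variant: replace the score with a fixed penalty value.
penalty_reward_value = -1.0
replaced = torch.where(contain_eos, scores, torch.full_like(scores, penalty_reward_value))

# `--missing_eos_penalty` variant: subtract a fixed penalty from the score.
missing_eos_penalty = 1.0
penalized = scores - missing_eos_penalty * (~contain_eos).float()
print(replaced.tolist(), penalized.tolist())  # [1.5, -1.0] [1.5, 0.5]
```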
## Benchmark experiments @@ -184,7 +183,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml Checkpoints and experiment tracking are available at: -- [🤗 Model checkpoint](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07) +- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07) - [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35) From bb267fb12e646e613f588d26d9241e9045d97cf5 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 20:51:18 +0000 Subject: [PATCH 87/92] Fix DPO config test --- tests/test_trainers_args.py | 62 +++++++------------------------------ 1 file changed, 11 insertions(+), 51 deletions(-) diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 897dbd82634..a8b4ef028cc 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -223,70 +223,30 @@ def test_online_dpo(self): with tempfile.TemporaryDirectory() as tmp_dir: args = OnlineDPOConfig( tmp_dir, - run_name="dummy_run_name", - sanity_check=True, - num_mini_batches=2, - total_episodes=100, - local_rollout_forward_batch_size=32, - num_sample_generations=20, - response_length=52, - stop_token="eos", - stop_token_id=123, + max_new_tokens=42, temperature=0.5, - penalty_reward_value=-2, - non_eos_penalty=True, - sft_model_path="EleutherAI/pythia-14m", - world_size=4, - num_total_batches=100, - micro_batch_size=32, - local_batch_size=64, - batch_size=256, - local_mini_batch_size=8, - mini_batch_size=32, - exp_name="dummy_exp_name", - reward_model_path="EleutherAI/pythia-14m", - num_epochs=2, - beta=0.1, - loss_type="ipo", - disable_dropout=False, + missing_eos_penalty=0.33, + beta=0.6, + loss_type="hinge", + dataset_num_proc=4, ) model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m") ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m") reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) trainer = OnlineDPOTrainer( - config=args, + args=args, tokenizer=tokenizer, model=model, ref_model=ref_model, reward_model=reward_model, train_dataset=dataset, ) - self.assertEqual(trainer.args.run_name, "dummy_run_name") - self.assertEqual(trainer.args.sanity_check, True) - self.assertEqual(trainer.args.num_mini_batches, 2) - self.assertEqual(trainer.args.total_episodes, 100) - self.assertEqual(trainer.args.local_rollout_forward_batch_size, 32) - self.assertEqual(trainer.args.num_sample_generations, 20) - self.assertEqual(trainer.args.response_length, 52) - self.assertEqual(trainer.args.stop_token, "eos") - self.assertEqual(trainer.args.stop_token_id, 123) + self.assertEqual(trainer.args.max_new_tokens, 42) self.assertEqual(trainer.args.temperature, 0.5) - self.assertEqual(trainer.args.penalty_reward_value, -2) - self.assertEqual(trainer.args.non_eos_penalty, True) - self.assertEqual(trainer.args.sft_model_path, "EleutherAI/pythia-14m") - # self.assertEqual(trainer.args.world_size, 4) - # self.assertEqual(trainer.args.num_total_batches, 100) - # self.assertEqual(trainer.args.micro_batch_size, 32) - # self.assertEqual(trainer.args.local_batch_size, 64) - # self.assertEqual(trainer.args.batch_size, 256) - self.assertEqual(trainer.args.local_mini_batch_size, 8) - # self.assertEqual(trainer.args.mini_batch_size, 32) - self.assertEqual(trainer.args.exp_name, "dummy_exp_name") - self.assertEqual(trainer.args.reward_model_path, "EleutherAI/pythia-14m") - 
self.assertEqual(trainer.args.num_epochs, 2) - self.assertEqual(trainer.args.beta, 0.1) - self.assertEqual(trainer.args.loss_type, "ipo") - self.assertEqual(trainer.args.disable_dropout, False) + self.assertEqual(trainer.args.missing_eos_penalty, 0.33) + self.assertEqual(trainer.args.beta, 0.6) + self.assertEqual(trainer.args.loss_type, "hinge") + self.assertEqual(trainer.args.dataset_num_proc, 4) def test_orpo(self): tokenizer = AutoTokenizer.from_pretrained("gpt2") From 6b7d559f38687832cecb435f0e93c418e5a83044 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 27 Aug 2024 20:53:45 +0000 Subject: [PATCH 88/92] Fix doc tree --- docs/source/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7648d2be46c..d9cad99ad3e 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -31,12 +31,12 @@ title: PPOv2 Trainer - local: rloo_trainer title: RLOO Trainer - - local: online_dpo_trainer - title: Online DPO Trainer - local: best_of_n title: Best of N Sampling - local: dpo_trainer title: DPO Trainer + - local: online_dpo_trainer + title: Online DPO Trainer - local: kto_trainer title: KTO Trainer - local: bco_trainer From 56afe56d2c553930ea355c3711c0fe6ee4dfcdfc Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 28 Aug 2024 07:22:06 +0000 Subject: [PATCH 89/92] Clean docs moar --- docs/source/online_dpo_trainer.md | 69 +++++++++++++++++++------------ 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md index 661ba5f0349..00a9f9c50ca 100644 --- a/docs/source/online_dpo_trainer.md +++ b/docs/source/online_dpo_trainer.md @@ -1,10 +1,21 @@ # Online DPO Trainer -TRL supports post-training LLMs with online DPO ([Guo et al., 2024](https://huggingface.co/papers/2402.04792)). The idea of online DPO is to generate completions per batch of prompts and have either a reward model or an LLM judge rank the responses as chosen or rejected. Then the model is updated with the ranked responses using the DPO loss. +## Overview -While [Guo et al. (2024)](https://huggingface.co/papers/2402.04792) used an LLM judge to score model completions, the current implementation only supports reward models -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. +Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel. -## Get started +The abstract from the paper is the following: + +> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. 
Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
+
+The current implementation uses reward models for scoring completions -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use.
+
+This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
+
+## Usage tips
+
+> [!IMPORTANT]
+> Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.

 The basic API is as follows:

@@ -17,6 +28,7 @@ from transformers import (
     AutoTokenizer,
 )

 NUM_DUMMY_SAMPLES = 100
+
 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 # The model to optimise
@@ -25,15 +37,18 @@ model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct
 ref_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
 # The model to score completions with. In practice, you will need a reward model.
 reward_model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", num_labels=1)
+
 train_dataset = Dataset.from_dict(
     {"prompt": ["Q: Hi how are you? A:"] * NUM_DUMMY_SAMPLES})
 eval_dataset = Dataset.from_dict(
     {"prompt": ["Q: What do you like to eat A:"] * NUM_DUMMY_SAMPLES})
+
+args = OnlineDPOConfig(output_dir="online-dpo-model")
 trainer = OnlineDPOTrainer(
     model=model,
     ref_model=ref_model,
     reward_model=reward_model,
-    args=OnlineDPOConfig(output_dir="online-dpo-model"),
+    args=args,
     tokenizer=tokenizer,
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
@@ -53,15 +68,20 @@ python examples/scripts/dpo_online.py \
     --per_device_train_batch_size 4 \
     --gradient_accumulation_steps 32 \
     --num_train_epochs 3 \
-    --completion_length 53 \
+    --max_new_tokens 53 \
     --warmup_ratio 0.1 \
     --missing_eos_penalty 1.0 \
     --push_to_hub
 ```
+
+Tips:
+
+* `objective/rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
+* We recommend using the "EOS trick" via the `--missing_eos_penalty` argument, which subtracts from the rewards a fixed scalar penalty for completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
+
+### Expected dataset format

-Unlike standard DPO where one provides a dataset with chosen and rejected columns, for online DPO one just needs a dataset of prompts to generate the completions from.
The [`OnlineDPOTrainer`] assumes that the dataset is preprocessed for model inference, so typically you will want to wrap your prompts in the messages format and then apply the chat template as follows: +Unlike offline DPO, where one provides a dataset with chosen and rejected columns, online DPO only requires a dataset of prompts to generate the completions from. The [`OnlineDPOTrainer`] assumes that the dataset is preprocessed for model inference, so typically you will need to wrap your prompts in the messages format and then apply the chat template as follows: ```python def prepare_dataset(row): @@ -72,7 +92,7 @@ def prepare_dataset(row): dataset = prepare_dataset(dataset) ``` -## Explanation of the logged metrics +### Explanation of the logged metrics The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35) @@ -88,35 +108,22 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an * `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. * `logps/chosen`: The mean log probabilities of the chosen completions. * `logps/rejected`: The mean log probabilities of the rejected completions. -* `val/contain_eos_token`: The fraction of completions which contain and EOS token. -* `lr`: lr: The current learning rate used by the optimizer. - - -## Cookbook - -> [!IMPORTANT] -> Make sure the SFT model and reward model use the _same_ chat template. Otherwise you may find the model completions are scored incorrectly. - - -* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. -* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. -* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. -* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtract to the score of completions that do not end with an EOS token a static scalar penalty. This can help the model learn to generate more coherent completions. +* `val/contain_eos_token`: The fraction of completions which contain an EOS token. ## What is my model doing exactly? -To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate during training, but you can customize the number of generations. +To help you understand what your model is doing, we periodically log some sample completions from the model via [`LogCompletionsCallback`]. You can find an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/hlzevfro?nw=nwuserlewtun), which allows you to see the model's response at different stages of training. By default we generate during training, but you can customize the number of prompts to generate for in [`LogCompletionsCallback`]. 
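As a rough sketch of how the periodic completion logging described in the hunk above can be wired up (the exact `LogCompletionsCallback` signature may differ at this revision; the module path and the `num_prompts` argument are assumptions, so check the callback in your TRL version):

```python
from trl.trainer.callbacks import LogCompletionsCallback

# Hypothetical wiring: log generations for a handful of eval prompts at each
# logging step; `trainer` is assumed to be an already-constructed OnlineDPOTrainer.
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
```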
## Implementation details

-Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). Here are some additional implementation details:
+Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).

 ## Benchmark experiments

-To validate the online DPO implementation works, we ran experiments on the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
+To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).

 ```
@@ -184,7 +191,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
 Checkpoints and experiment tracking are available at:

 - [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
-- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
+- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)

 To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
@@ -229,7 +236,15 @@ plt.tight_layout()
 plt.show()
 ```

-
 ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online_dpo_scaling.png)

 The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
+
+## OnlineDPOTrainer
+
+[[autodoc]] OnlineDPOTrainer
+
+
+## OnlineDPOConfig
+
+[[autodoc]] OnlineDPOConfig
\ No newline at end of file

From f78ff61317717937a217b07828b7e702f4295979 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Wed, 28 Aug 2024 07:35:58 +0000
Subject: [PATCH 90/92] Add docstring

---
 docs/source/online_dpo_trainer.md |  2 +-
 trl/trainer/online_dpo_trainer.py | 37 +++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md
index 00a9f9c50ca..0b7c9ffcbe4 100644
--- a/docs/source/online_dpo_trainer.md
+++ b/docs/source/online_dpo_trainer.md
@@ -14,7 +14,7 @@ This post-training method was contributed by [Michael Noukhovitch](https://huggi

 ## Usage tips

-> [!IMPORTANT]
+> [!WARNING]
 > Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
The basic API is as follows:

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index df188853ebc..2df24682b17 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -52,6 +52,43 @@

 class OnlineDPOTrainer(Trainer):
+    r"""
+    Initialize OnlineDPOTrainer.
+
+    Args:
+        model (`transformers.PreTrainedModel`):
+            The model to train, preferably an `AutoModelForCausalLM`.
+        ref_model (`PreTrainedModelWrapper`):
+            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
+            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
+        reward_model (`transformers.PreTrainedModel`):
+            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
+        judge (`BasePairwiseJudge`):
+            The judge to use for pairwise comparison of model completions.
+        args (`OnlineDPOConfig`):
+            The online DPO config arguments to use for training.
+        data_collator (`transformers.DataCollator`):
+            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
+            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+        train_dataset (`datasets.Dataset`):
+            The dataset to use for training.
+        eval_dataset (`datasets.Dataset`):
+            The dataset to use for evaluation.
+        tokenizer (`transformers.PreTrainedTokenizerBase`):
+            The tokenizer to use for training. This argument is required if you want to use the default data collator.
+        model_init (`Callable[[], transformers.PreTrainedModel]`):
+            The model initializer to use for training. If None is specified, the default model initializer will be used.
+        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+            The function to use to compute the metrics. Must take an `EvalPrediction` and return
+            a dictionary of metric names to values.
+        callbacks (`List[transformers.TrainerCallback]`):
+            The callbacks to use for training.
+        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+            The optimizer and scheduler to use for training.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+            The function to use to preprocess the logits before computing the metrics.
+    """
+
     _tag_names = ["trl", "online-dpo"]

     def __init__(

From bc13c33898e47367715da4934084c1f784f46e4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 28 Aug 2024 08:45:09 +0000
Subject: [PATCH 91/92] raise NotImplemented error for judge

---
 trl/trainer/online_dpo_trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 2df24682b17..0d1edce9da4 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -117,6 +117,8 @@ def __init__(
             )
         elif reward_model is None and judge is None:
             raise ValueError("Either `reward_model` or `judge` must be provided.")
+        elif reward_model is None and judge is not None:
+            raise NotImplementedError("Using `judge` is not yet supported.")

         self.reward_model = reward_model
         self.judge = judge

From 57eb67338e4a20d41ab6fb33a801fd973bca10c5 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Wed, 28 Aug 2024 13:09:47 +0000
Subject: [PATCH 92/92] Refactor cache clearing

---
 trl/trainer/__init__.py           |  1 +
 trl/trainer/online_dpo_trainer.py | 20 ++++----------------
 trl/trainer/utils.py              | 23 +++++++++++++++++++++++
 3 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 0d66c9db8fe..77f3162aef1 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -86,6 +86,7 @@
     RunningMoments,
     disable_dropout_in_model,
     peft_module_casting_to_bf16,
+    empty_cache,
 )

 # isort: on

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 0d1edce9da4..a95c2d2c0da 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -18,9 +18,6 @@
 from transformers.utils import (
     is_apex_available,
     is_sagemaker_mp_enabled,
-    is_torch_mlu_available,
-    is_torch_npu_available,
-    is_torch_xpu_available,
     logging,
 )

@@ -29,6 +26,7 @@
 from .online_dpo_config import OnlineDPOConfig
 from .utils import (
     DPODataCollatorWithPadding,
+    empty_cache,
     get_reward,
     prepare_deepspeed,
     trl_sanitze_kwargs_for_tagging,
@@ -320,7 +318,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             # Take the completion tokens logprob
             logprobs = torch.take_along_dim(all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
         del output, logits, all_logprobs  # free memory
-        self.empty_cache()
+        empty_cache()

         # Same for the reference model
         with torch.no_grad():
@@ -329,7 +327,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             ref_all_logprobs = F.log_softmax(ref_logits, dim=-1)
             ref_logprobs = torch.take_along_dim(ref_all_logprobs, completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
         del ref_output, ref_logits, ref_all_logprobs  # free memory
-        self.empty_cache()
+        empty_cache()

         # Get the reward from the reward model
         with torch.no_grad():
@@ -415,7 +413,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             self.args.torch_empty_cache_steps is not None
             and self.state.global_step % self.args.torch_empty_cache_steps == 0
         ):
-            self.empty_cache()
+            empty_cache()

         kwargs = {}

@@ -469,16 +467,6 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             self._save_checkpoint(model, trial, metrics=metrics)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)

-    def empty_cache(self):
-        if is_torch_xpu_available():
-            torch.xpu.empty_cache()
-        elif is_torch_mlu_available():
-            torch.mlu.empty_cache()
-        elif 
is_torch_npu_available(): - torch.npu.empty_cache() - else: - torch.cuda.empty_cache() - @wraps(Trainer.push_to_hub) def push_to_hub( self, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 500d5517f4d..a71797d4b2b 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -36,6 +36,11 @@ TrainerState, TrainingArguments, ) +from transformers.utils import ( + is_torch_mlu_available, + is_torch_npu_available, + is_torch_xpu_available, +) from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available from ..trainer.model_config import ModelConfig @@ -1241,3 +1246,21 @@ def truncate_right( output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) return output_ids, mask + + +def empty_cache() -> None: + """Empties the cache of the available torch device. + + This function checks for the availability of different torch devices (XPU, MLU, NPU, CUDA) + and empties the cache of the first available device it finds. + + If none of the specific devices are available, it defaults to emptying the CUDA cache. + """ + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache()
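To round off the series, here is a minimal usage sketch of the `empty_cache` helper that patch 92 moves into `trl/trainer/utils.py` (the tensor below is a stand-in; inside the trainer the helper is called right after `del`-ing large intermediates, as the hunks above show):

```python
import torch

from trl.trainer.utils import empty_cache  # added and exported by this patch

logits = torch.randn(8, 128, 32000)  # stand-in for a large intermediate tensor
del logits     # drop the Python reference first so the memory becomes reclaimable
empty_cache()  # then release cached allocator blocks on whichever backend is available
```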