From e085ba2d10a71a429c6153e1a0054deb34dc1153 Mon Sep 17 00:00:00 2001 From: Jingru Date: Tue, 31 Oct 2023 05:46:10 +0000 Subject: [PATCH 1/6] support parallel reward function --- trlx/data/configs.py | 2 + trlx/trainer/accelerate_base_trainer.py | 24 +++++++----- trlx/trainer/accelerate_ppo_trainer.py | 51 +++++++++++++++---------- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/trlx/data/configs.py b/trlx/data/configs.py index a94073107..0f3610789 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -231,6 +231,8 @@ class TrainConfig: minibatch_size: Optional[int] = None + reward_only_in_main_process: bool = True + @classmethod def from_dict(cls, config: Dict[str, Any]): return cls(**config) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index e15fb06da..7809f2e3d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -387,19 +387,23 @@ def evaluate(self): # noqa: C901 if self.config.model.model_arch_type == "seq2seq": samples = samples[:, 1:].contiguous() - prompt_sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(len(prompts.input_ids)) - prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics( - self.accelerator.pad_across_processes( - [prompts.input_ids, samples, prompt_sizes.to(samples.device)], - dim=1, - pad_index=self.tokenizer.pad_token_id, - ) + prompt_sizes = torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat( + len(prompts.input_ids) ) + if self.config.train.reward_only_in_main_process: + prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics( + self.accelerator.pad_across_processes( + [prompts.input_ids, samples, prompt_sizes], + dim=1, + pad_index=self.tokenizer.pad_token_id, + ) + ) + metadata = gather_dict(metadata, self.accelerator.gradient_state) + else: + prompts = prompts.input_ids all_samples.extend(samples.tolist()) all_prompts.extend(prompts.tolist()) all_prompt_sizes.extend(prompt_sizes.tolist()) - - metadata = gather_dict(metadata, self.accelerator.gradient_state) all_metadata.append(metadata) desc = [ @@ -412,7 +416,7 @@ def evaluate(self): # noqa: C901 stats["time/generate"] = time() - generate_time - if self.accelerator.is_main_process: + if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process: str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) columns = ["prompt", "output"] diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 1a4801aaf..acf5c256e 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -289,18 +289,25 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq device = samples.device prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) - padded_samples = self.accelerator.pad_across_processes( - samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False - ) - padded_prompts = self.accelerator.pad_across_processes( - prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False - ) - gathered_samples = self.accelerator.gather(padded_samples) - gathered_prompts = self.accelerator.gather(padded_prompts) - gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = {k: v for k, v in batch.items() if k != 
"input_ids" and k != "attention_mask"} - if self.accelerator.is_main_process: + if self.config.train.reward_only_in_main_process: + padded_samples = self.accelerator.pad_across_processes( + samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + padded_prompts = self.accelerator.pad_across_processes( + prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + gathered_samples = self.accelerator.gather(padded_samples) + gathered_prompts = self.accelerator.gather(padded_prompts) + gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) + metadata = gather_dict(metadata) + else: + gathered_samples = samples + gathered_prompts = prompt_tensors + gathered_prompt_sizes = prompt_sizes + + if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process: all_str_samples, all_str_prompts, all_str_outputs = self.decode( gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True ) @@ -316,9 +323,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq **metadata, ) all_scores = [ - torch.tensor(score, dtype=torch.float, device=device).view( - -1, - ) + score.view(-1) + if isinstance(score, torch.Tensor) + else torch.tensor(score, dtype=torch.float, device=device).view(-1) for score in all_scores ] # Pad 0 reward on the ends @@ -327,17 +334,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind()) + if self.config.train.reward_only_in_main_process: + all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind()) else: all_scores = None max_len = torch.tensor(0, dtype=torch.long, device=device) - if torch.distributed.is_initialized(): - torch.distributed.broadcast(max_len, 0) - scores = torch.empty((len(samples), max_len), device=device) - torch.distributed.scatter(scores, all_scores) + if self.config.train.reward_only_in_main_process: + if torch.distributed.is_initialized(): + torch.distributed.broadcast(max_len, 0) + scores = torch.empty((len(samples), max_len), device=device) + torch.distributed.scatter(scores, all_scores) + else: + scores = all_scores[0].clone().detach() else: - scores = all_scores[0].clone().detach() + scores = all_scores scores_mask = scores != -np.inf str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) From d283ee282ab78aa0b7f2ed3be4b2b069769a42a9 Mon Sep 17 00:00:00 2001 From: Jingru Date: Wed, 15 Nov 2023 05:51:25 +0000 Subject: [PATCH 2/6] feat: support parallel reward function --- trlx/trainer/accelerate_ppo_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index acf5c256e..f004e0544 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -351,7 +351,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = all_scores scores_mask = scores != -np.inf - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + if self.config.train.reward_only_in_main_process: + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + else: + str_samples, str_prompts, str_outputs = all_str_samples, all_str_prompts, all_str_outputs # Pad the 
sample outputs outputs = self.tokenizer(str_outputs).input_ids From fd3d95ba0fc33da08beb51b6a1bca4036b75ec74 Mon Sep 17 00:00:00 2001 From: Jingru Date: Wed, 15 Nov 2023 19:58:08 +0800 Subject: [PATCH 3/6] Update accelerate_ppo_trainer.py --- trlx/trainer/accelerate_ppo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index f004e0544..424d3d873 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -352,9 +352,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores_mask = scores != -np.inf if self.config.train.reward_only_in_main_process: - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + _, _, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) else: - str_samples, str_prompts, str_outputs = all_str_samples, all_str_prompts, all_str_outputs + str_outputs = all_str_outputs # Pad the sample outputs outputs = self.tokenizer(str_outputs).input_ids From aa859886cdc5e717bbb21835a5146d4a24959bf9 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 17 Nov 2023 09:36:31 +0000 Subject: [PATCH 4/6] support parallel reward function --- trlx/trainer/accelerate_base_trainer.py | 60 +++++++++++++++---------- trlx/trainer/accelerate_ppo_trainer.py | 8 ++-- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 7809f2e3d..38044fa7f 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -419,8 +419,9 @@ def evaluate(self): # noqa: C901 if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process: str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) - columns = ["prompt", "output"] columns_data = [str_prompts, str_outputs] + if not self.config.train.reward_only_in_main_process: + columns_data = self.accelerator.gather_for_metrics(columns_data) metadata, *xs = all_metadata for k in metadata: @@ -443,41 +444,52 @@ def evaluate(self): # noqa: C901 rewards = torch.tensor([sum(reward) for reward in rewards], dtype=float) else: rewards = torch.tensor(rewards, dtype=float) - mean_reward = rewards.mean().item() - columns.append("reward") - if not isinstance(rewards, list): - rewards = rewards.tolist() - columns_data.append(rewards) - stats[f"reward/mean{sweep_suffix}"] = mean_reward + + if not self.config.train.reward_only_in_main_process: + rewards = self.accelerator.gather(rewards) + if self.accelerator.is_main_process: + mean_reward = rewards.mean().item() + + columns = ["prompt", "output", "reward"] + if not isinstance(rewards, list): + rewards = rewards.tolist() + columns_data.append(rewards) + stats[f"reward/mean{sweep_suffix}"] = mean_reward # additionally log any other metrics if self.metric_fn: logger.info("Computing metrics") metric_time = time() metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata) - stats["time/metric"] = time() - metric_time + if not self.config.train.reward_only_in_main_process: + metrics = self.accelerator.gather_for_metrics(metrics) - mean_metrics = { - f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() for k, xs in metrics.items() - } + if self.accelerator.is_main_process: + stats["time/metric"] = time() - metric_time - stats.update(mean_metrics) + 
mean_metrics = { + f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() + for k, xs in metrics.items() + } - for metric, values in metrics.items(): - # Skip metrics that are scalers since they represent aggregated values - if isinstance(values, float): - continue - columns.append(metric) - if not isinstance(values, list): - values = values.tolist() - columns_data.append(values) + stats.update(mean_metrics) + + for metric, values in metrics.items(): + # Skip metrics that are scalers since they represent aggregated values + if isinstance(values, float): + continue + columns.append(metric) + if not isinstance(values, list): + values = values.tolist() + columns_data.append(values) # Prepend the sweep argument along with samples - if self.generate_sweep_kwarg: - columns.insert(0, gen_sweep_arg) - columns_data.insert(0, [gen_sweep_value] * len(samples)) + if self.accelerator.is_main_process: + if self.generate_sweep_kwarg: + columns.insert(0, gen_sweep_arg) + columns_data.insert(0, [gen_sweep_value] * len(samples)) - table.append(list(zip(*columns_data))) + table.append(list(zip(*columns_data))) # Log and display evaluation metrics logger.info("Summarizing evaluation") diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 424d3d873..1060582e9 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -344,17 +344,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq if torch.distributed.is_initialized(): torch.distributed.broadcast(max_len, 0) scores = torch.empty((len(samples), max_len), device=device) - torch.distributed.scatter(scores, all_scores) + torch.distributed.scatter(scores, all_scores) # scores is one shard of one process after scatter else: - scores = all_scores[0].clone().detach() + scores = all_scores[0].clone().detach() # shard of one process else: - scores = all_scores + scores = all_scores.clone().detach() # shard of one process + # `all_scores` no longer used, no need to gather it scores_mask = scores != -np.inf if self.config.train.reward_only_in_main_process: _, _, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) else: str_outputs = all_str_outputs + # `all_str_outputs` no longer used, no need to gather it # Pad the sample outputs outputs = self.tokenizer(str_outputs).input_ids From 68385805130d70e8a19ba7c9aa3e6543fa9e3dfe Mon Sep 17 00:00:00 2001 From: Jingru Date: Wed, 6 Dec 2023 09:37:48 +0000 Subject: [PATCH 5/6] support parallel reward function --- trlx/trainer/accelerate_base_trainer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 38044fa7f..bebd85d24 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -419,9 +419,11 @@ def evaluate(self): # noqa: C901 if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process: str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) - columns_data = [str_prompts, str_outputs] - if not self.config.train.reward_only_in_main_process: - columns_data = self.accelerator.gather_for_metrics(columns_data) + if self.accelerator.is_main_process: + columns = ["prompt", "output"] + columns_data = [str_prompts, str_outputs] + if not self.config.train.reward_only_in_main_process: + columns_data = self.accelerator.gather_for_metrics(columns_data) metadata, 
*xs = all_metadata for k in metadata: @@ -445,12 +447,12 @@ def evaluate(self): # noqa: C901 else: rewards = torch.tensor(rewards, dtype=float) - if not self.config.train.reward_only_in_main_process: - rewards = self.accelerator.gather(rewards) if self.accelerator.is_main_process: + if not self.config.train.reward_only_in_main_process: + rewards = self.accelerator.gather(rewards) mean_reward = rewards.mean().item() - columns = ["prompt", "output", "reward"] + columns.append("reward") if not isinstance(rewards, list): rewards = rewards.tolist() columns_data.append(rewards) From 2cf8af54a4424f6494a9fe1f437d24364d6d697f Mon Sep 17 00:00:00 2001 From: Jingru Date: Thu, 7 Dec 2023 07:54:07 +0000 Subject: [PATCH 6/6] support parallel reward function --- trlx/trainer/accelerate_base_trainer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index bebd85d24..3a8456427 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -421,9 +421,11 @@ def evaluate(self): # noqa: C901 if self.accelerator.is_main_process: columns = ["prompt", "output"] - columns_data = [str_prompts, str_outputs] - if not self.config.train.reward_only_in_main_process: - columns_data = self.accelerator.gather_for_metrics(columns_data) + + # gather should be invoked in every process, not just the main process + columns_data = [str_prompts, str_outputs] + if not self.config.train.reward_only_in_main_process: + columns_data = self.accelerator.gather_for_metrics(columns_data) metadata, *xs = all_metadata for k in metadata: @@ -447,9 +449,11 @@ def evaluate(self): # noqa: C901 else: rewards = torch.tensor(rewards, dtype=float) + # gather should be invoked in every process, not just the main process + if not self.config.train.reward_only_in_main_process: + rewards = self.accelerator.gather(rewards) + if self.accelerator.is_main_process: - if not self.config.train.reward_only_in_main_process: - rewards = self.accelerator.gather(rewards) mean_reward = rewards.mean().item() columns.append("reward")
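
Usage sketch (not from the patches themselves): with `train.reward_only_in_main_process` set to False, the reward function is called in every process on that process's local rollouts instead of only on rank 0, which is what makes a parallel (per-rank) reward model possible. The snippet below is a minimal sketch assuming trlx's `default_ppo_config` helper; the prompts are placeholders and a toy length-based reward stands in for a real per-rank reward model.

    import trlx
    from trlx.data.default_configs import default_ppo_config

    config = default_ppo_config()
    # New flag from this patch series: when False, every process scores its own local
    # rollouts, so samples and scores are no longer gathered to / scattered from rank 0.
    config.train.reward_only_in_main_process = False

    def reward_fn(samples, prompts, outputs, **kwargs):
        # With the flag set to False, this runs in every process on that process's local
        # samples and must return one score per local sample (a list of floats or a
        # 1-D tensor). A real setup would call a per-rank reward-model replica or shard
        # here; this toy reward simply favours longer outputs.
        return [float(len(output.split())) for output in outputs]

    trainer = trlx.train(
        reward_fn=reward_fn,
        prompts=["Tell me a short story:"] * 64,      # placeholder prompts
        eval_prompts=["Tell me a short story:"] * 8,  # placeholder eval prompts
        config=config,
    )

Leaving the flag at its default (True) preserves the previous behaviour, where rollouts are gathered and scored on the main process and the scores scattered back, so existing single-process reward functions keep working unchanged.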