From 9e117078650ca673f1a8b1fb35f3e8044c757ef3 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 13 Jun 2024 00:12:04 +0000
Subject: [PATCH 1/2] prepare deepspeed accommodate fp16 and bf16

---
 trl/trainer/ppov2_trainer.py | 20 ++++++++++++--------
 trl/trainer/rloo_trainer.py  | 14 +++++++++-----
 trl/trainer/utils.py         |  9 +++++++--
 3 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py
index dc74f3b3523..6442ee7afd6 100644
--- a/trl/trainer/ppov2_trainer.py
+++ b/trl/trainer/ppov2_trainer.py
@@ -190,8 +190,12 @@ def __init__(
         self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
 
         if self.is_deepspeed_enabled:
-            self.reward_model = prepare_deepspeed(self.reward_model, args.per_device_train_batch_size)
-            self.ref_policy = prepare_deepspeed(self.ref_policy, args.per_device_train_batch_size)
+            self.reward_model = prepare_deepspeed(
+                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+            )
+            self.ref_policy = prepare_deepspeed(
+                self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
+            )
         else:
             self.ref_policy = self.ref_policy.to(self.accelerator.device)
             self.reward_model = self.reward_model.to(self.accelerator.device)
@@ -447,14 +451,14 @@ def repeat_generator():
                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                 approxkl = 0.5 * (logprobs_diff**2).mean()
                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                pg_clipfrac_stats[
-                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
-                                ] = pg_clipfrac
+                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                                    pg_clipfrac
+                                )
                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                 vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
-                                vf_clipfrac_stats[
-                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
-                                ] = vf_clipfrac
+                                vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                                    vf_clipfrac
+                                )
                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                         gradient_accumulation_idx += 1
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 02f69df5e52..6188f2d90ff 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -175,8 +175,12 @@ def __init__(
         self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
 
         if self.is_deepspeed_enabled:
-            self.reward_model = prepare_deepspeed(self.reward_model, args.per_device_train_batch_size)
-            self.ref_policy = prepare_deepspeed(self.ref_policy, args.per_device_train_batch_size)
+            self.reward_model = prepare_deepspeed(
+                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+            )
+            self.ref_policy = prepare_deepspeed(
+                self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
+            )
             self.deepspeed = self.model
         else:
             self.ref_policy = self.ref_policy.to(self.accelerator.device)
@@ -367,9 +371,9 @@ def repeat_generator():
                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                 approxkl = 0.5 * (logprobs_diff**2).mean()
                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                pg_clipfrac_stats[
-                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
-                                ] = pg_clipfrac
+                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                                    pg_clipfrac
+                                )
                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index c0197336e7b..0aeabf8905c 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1001,7 +1001,9 @@ def forward(
     )
 
 
-def prepare_deepspeed(model: torch.nn.Module, per_device_train_batch_size: int):
+def prepare_deepspeed(
+    model: torch.nn.Module, per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False
+):
     """
     Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings
     based on the model and batch size.
@@ -1024,10 +1026,13 @@ def prepare_deepspeed(model: torch.nn.Module, per_device_train_batch_size: int):
         config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size
         config_kwargs = {
             "train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"],
-            "bf16": {"enabled": True},
             "prescale_gradients": False,
             "wall_clock_breakdown": False,
         }
+        if bf16:
+            config_kwargs["bf16"] = {"enabled": True}
+        elif fp16:
+            config_kwargs["fp16"] = {"enabled": True}
     else:
         if hasattr(model, "config"):
             hidden_size = (

From 676779905518c4b24508b3b781861ad6a7145f41 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 13 Jun 2024 17:05:25 +0000
Subject: [PATCH 2/2] precommit

---
 trl/trainer/ppov2_trainer.py | 12 ++++++------
 trl/trainer/rloo_trainer.py  |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py
index 6442ee7afd6..e9f1eec8216 100644
--- a/trl/trainer/ppov2_trainer.py
+++ b/trl/trainer/ppov2_trainer.py
@@ -451,14 +451,14 @@ def repeat_generator():
                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                 approxkl = 0.5 * (logprobs_diff**2).mean()
                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                    pg_clipfrac
-                                )
+                                pg_clipfrac_stats[
+                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
+                                ] = pg_clipfrac
                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                 vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
-                                vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                    vf_clipfrac
-                                )
+                                vf_clipfrac_stats[
+                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
+                                ] = vf_clipfrac
                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                         gradient_accumulation_idx += 1
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 6188f2d90ff..0725358d7af 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -371,9 +371,9 @@ def repeat_generator():
                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                 approxkl = 0.5 * (logprobs_diff**2).mean()
                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                    pg_clipfrac
-                                )
+                                pg_clipfrac_stats[
+                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
+                                ] = pg_clipfrac
                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
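
For context, the net effect of [PATCH 1/2] on prepare_deepspeed is that the mixed-precision section of the DeepSpeed config is now chosen from the trainer's fp16/bf16 flags instead of always enabling bf16. The selection logic reduces to the standalone Python sketch below; the build_precision_config name and the asserts are illustrative only, not part of the patch (the real function starts from the Accelerate deepspeed_plugin config and also has a ZeRO stage 3 branch):

def build_precision_config(per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False) -> dict:
    # Base keys, mirroring what prepare_deepspeed sets for ZeRO stages 0-2.
    config_kwargs = {
        "train_micro_batch_size_per_gpu": per_device_train_batch_size,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
    }
    # bf16 takes precedence when both flags are set (same if/elif order as the patch);
    # with neither flag set, no precision section is added and DeepSpeed stays in fp32.
    if bf16:
        config_kwargs["bf16"] = {"enabled": True}
    elif fp16:
        config_kwargs["fp16"] = {"enabled": True}
    return config_kwargs

assert "bf16" not in build_precision_config(8)                    # fp32 by default
assert "fp16" in build_precision_config(8, fp16=True)             # fp16 run
assert "bf16" in build_precision_config(8, fp16=True, bf16=True)  # bf16 wins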