diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index 91461a9b0d0..c6ce2d49025 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -1260,6 +1260,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -1290,7 +1291,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             return None
         return SequentialSampler(self.train_dataset)
 
-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""
 
         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1407,7 +1408,7 @@ def evaluation_loop(
                 "prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
                 "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
             }
-            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
 
             self.log(
                 {
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 5847cb182be..5e74fdaceb0 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -828,6 +828,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -847,7 +848,7 @@ def compute_loss(
             return (loss, metrics)
         return loss
 
-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
         """Generate samples from the model and reference model for the given batch of inputs."""
 
         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -938,7 +939,7 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)
 
-            policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+            policy_output_decoded = self.generate_from_model(self.model, random_batch)
 
             self.log(
                 {
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 8b9843cd916..082a627ce00 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1547,6 +1547,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
         with compute_loss_context_manager:
@@ -1561,7 +1562,7 @@ def compute_loss(
             return (loss, metrics)
         return loss
 
-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""
 
         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1672,7 +1673,7 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)
 
-            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch)
 
             self.log(
                 {
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 1b7c77557db..49e93e269b6 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -215,7 +215,7 @@ def generalized_jsd_loss(
         else:
             return jsd
 
-    def compute_loss(self, model, inputs, return_outputs=False):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # compute student output
         outputs_student = model(
             input_ids=inputs["input_ids"],
@@ -273,7 +273,9 @@ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=No
 
         return generated_tokens, new_attention_mask, new_labels
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Perform a training step for the Generalized Knowledge Distillation (GKD) model.
 
@@ -298,7 +300,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             inputs["attention_mask"] = new_attention_mask
             inputs["labels"] = new_labels
 
-        loss = super().training_step(model, inputs)
+        loss = super().training_step(model, inputs, num_items_in_batch)
         return loss
 
     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index ab9ba87e413..7f324248128 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1234,6 +1234,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -1264,7 +1265,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             return None
         return SequentialSampler(self.train_dataset)
 
-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""
 
         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1383,7 +1384,7 @@ def evaluation_loop(
                 "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
                 "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
             }
-            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
 
             self.log(
                 {
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index db0c3046b37..73aab7899ae 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -328,7 +328,9 @@ def gather_mean(tensor):
         self.stats["beta"].append(self.beta)
         self.stats["mixture_coef"].append(self.mixture_coef)
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         model.train()
 
         # Apply chat template and tokenize the input
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index ffc407b57db..c480c61fc52 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -366,7 +366,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
 
         return self.accelerator.prepare(eval_dataloader)
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         model.train()
 
         # Apply chat template and tokenize the input.
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 4edbf9b1a5a..123f9352080 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -844,6 +844,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -866,7 +867,7 @@ def compute_loss(
             return (loss, metrics)
         return loss
 
-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
         """Generate samples from the model and reference model for the given batch of inputs."""
 
         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -957,7 +958,7 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)
 
-            policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+            policy_output_decoded = self.generate_from_model(self.model, random_batch)
 
             self.log(
                 {
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 787c6cbd54b..0ebdee68b44 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -266,6 +266,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_reward_data_collator:
             warnings.warn(
diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py
index 0255e6206f8..a1548758213 100644
--- a/trl/trainer/xpo_trainer.py
+++ b/trl/trainer/xpo_trainer.py
@@ -377,7 +377,9 @@ def gather_mean(tensor):
         self.stats["alpha"].append(self.alpha)
         self.stats["beta"].append(self.beta)
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         model.train()
 
         # Apply chat template and tokenize the input
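
This diff tracks the updated `transformers.Trainer` API: recent `transformers` releases pass a `num_items_in_batch` argument to `compute_loss` and `training_step` (used to normalize the loss correctly under gradient accumulation), and the base class now also defines its own `get_batch_samples` helper. The overrides above therefore accept and forward the new argument, and TRL's unrelated sampling helper is renamed to `generate_from_model` / `generate_from_model_and_ref` so it no longer shadows the base-class method. Below is a minimal sketch of the same pattern for a custom subclass; the class name `MyTrainer` and the trivial loss body are illustrative only, not part of this diff.

```python
# Minimal sketch of the override pattern applied above; `MyTrainer` and its
# trivial loss body are illustrative only, not part of this diff.
from typing import Any, Dict, Optional, Union

import torch
from torch import nn
from transformers import Trainer


class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Accept the new keyword so recent `Trainer` versions can call this
        # override without a TypeError; older versions simply never pass it.
        outputs = model(**inputs)
        loss = outputs.loss  # assumes `inputs` carries labels
        return (loss, outputs) if return_outputs else loss

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        # Forward the token count so the parent keeps normalizing the loss
        # correctly across gradient-accumulation steps.
        return super().training_step(model, inputs, num_items_in_batch)
```

A subclass that keeps the old two-argument `training_step` or three-argument `compute_loss` signature raises a `TypeError` as soon as the base `Trainer` starts passing the extra argument, which is what the signature changes above avoid.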