trl/trainer/bco_trainer.py (5 changes: 3 additions & 2 deletions)
@@ -1260,6 +1260,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -1290,7 +1291,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return SequentialSampler(self.train_dataset)

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:

Review comment (Member):
Note for future: now that we have LogCompletions callback, it might be possible to enable the generative aspects of this method directly as a callback. We'd probably have to extend the LogCompletions callback to check if a reference model exists and generate for that too, but that seems better than having this code duplicated all over our preference trainers
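
A rough sketch of what such a callback extension could look like, for illustration only (the class name, constructor arguments, and reporting mechanism are assumptions, not existing TRL API):

```python
from transformers import TrainerCallback


class PolicyAndRefCompletionsCallback(TrainerCallback):
    """Hypothetical extension of LogCompletions that also samples the reference model."""

    def __init__(self, trainer, tokenizer, prompts, generation_config=None):
        self.trainer = trainer
        self.tokenizer = tokenizer  # assumed to have a pad token set
        self.prompts = prompts
        self.generation_config = generation_config

    def on_evaluate(self, args, state, control, **kwargs):
        policy_texts = self._generate(self.trainer.model)
        # Only trainers that keep an explicit reference model (DPO, KTO, BCO, ...) expose `ref_model`.
        ref_model = getattr(self.trainer, "ref_model", None)
        ref_texts = self._generate(ref_model) if ref_model is not None else None
        # Report however the existing LogCompletions callback does (e.g. a W&B table);
        # printing keeps this sketch self-contained.
        for i, text in enumerate(policy_texts):
            print(f"[policy] {text}")
            if ref_texts is not None:
                print(f"[ref]    {ref_texts[i]}")

    def _generate(self, model):
        inputs = self.tokenizer(self.prompts, return_tensors="pt", padding=True).to(model.device)
        output_ids = model.generate(**inputs, generation_config=self.generation_config)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```

That would keep the generation logic in one place instead of duplicating it across the preference trainers, which is exactly the concern raised above.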

"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1407,7 +1408,7 @@ def evaluation_loop(
"prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
}
- policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

self.log(
{
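For context on the signature change repeated throughout this PR: recent transformers releases (around 4.46) pass an extra num_items_in_batch argument into Trainer.compute_loss so that the loss can be normalized over the true token count of a gradient-accumulation group, and any override has to accept it to stay compatible. A minimal sketch of a compatible override (MyPreferenceTrainer is a hypothetical example, not TRL code):

```python
from typing import Optional

from transformers import Trainer


class MyPreferenceTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch: Optional[int] = None):
        # The base Trainer supplies num_items_in_batch; a custom loss that does its own
        # reduction can accept the argument and simply ignore it, which is what the TRL
        # trainers touched in this PR do.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
```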

trl/trainer/cpo_trainer.py (5 changes: 3 additions & 2 deletions)
@@ -828,6 +828,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -847,7 +848,7 @@ def compute_loss(
return (loss, metrics)
return loss

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -938,7 +939,7 @@ def evaluation_loop(
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

- policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)

self.log(
{

trl/trainer/dpo_trainer.py (5 changes: 3 additions & 2 deletions)
@@ -1547,6 +1547,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
with compute_loss_context_manager:
@@ -1561,7 +1562,7 @@ def compute_loss(
return (loss, metrics)
return loss

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1672,7 +1673,7 @@ def evaluation_loop(
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

- policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch)

self.log(
{
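The get_batch_samples renames are presumably defensive: newer transformers versions define their own Trainer.get_batch_samples helper (used internally for gradient-accumulation bookkeeping) with an unrelated signature, so a TRL method of the same name would shadow it. A toy illustration of that kind of clash, under that assumption (names are illustrative only):

```python
class Base:
    def get_batch_samples(self, epoch_iterator, num_batches):
        # Upstream-style helper: fetch `num_batches` batches and count the items in them.
        return [next(epoch_iterator) for _ in range(num_batches)], None


class Sub(Base):
    def get_batch_samples(self, model, batch):
        # Unrelated generation helper with the same name: it silently shadows the base
        # method, and the base class then calls it with the wrong arguments.
        return "decoded completions"
```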

trl/trainer/gkd_trainer.py (8 changes: 5 additions & 3 deletions)
@@ -215,7 +215,7 @@ def generalized_jsd_loss(
else:
return jsd

- def compute_loss(self, model, inputs, return_outputs=False):
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# compute student output
outputs_student = model(
input_ids=inputs["input_ids"],
@@ -273,7 +273,9 @@ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=No

return generated_tokens, new_attention_mask, new_labels

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
"""
Perform a training step for the Generalized Knowledge Distillation (GKD) model.

@@ -298,7 +300,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels

- loss = super().training_step(model, inputs)
+ loss = super().training_step(model, inputs, num_items_in_batch)
return loss

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
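The training_step overrides follow the same compatibility pattern: the base Trainer now calls training_step with num_items_in_batch, so a subclass must accept the argument and forward it when it delegates to super(). A minimal sketch (hypothetical subclass, not TRL code):

```python
from typing import Any, Dict, Optional, Union

import torch
from torch import nn
from transformers import Trainer


class MyOnPolicyTrainer(Trainer):
    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        # Rewrite `inputs` here (e.g. generate on-policy completions as GKDTrainer does),
        # then hand everything, including num_items_in_batch, back to the base implementation.
        return super().training_step(model, inputs, num_items_in_batch)
```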

trl/trainer/kto_trainer.py (5 changes: 3 additions & 2 deletions)
@@ -1234,6 +1234,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -1264,7 +1265,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return SequentialSampler(self.train_dataset)

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1383,7 +1384,7 @@ def evaluation_loop(
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
}
- policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

self.log(
{

trl/trainer/nash_md_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -328,7 +328,9 @@ def gather_mean(tensor):
self.stats["beta"].append(self.beta)
self.stats["mixture_coef"].append(self.mixture_coef)

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input

trl/trainer/online_dpo_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -366,7 +366,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None

return self.accelerator.prepare(eval_dataloader)

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input.

trl/trainer/orpo_trainer.py (5 changes: 3 additions & 2 deletions)
@@ -844,6 +844,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -866,7 +867,7 @@ def compute_loss(
return (loss, metrics)
return loss

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -957,7 +958,7 @@ def evaluation_loop(
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

- policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)

self.log(
{

trl/trainer/reward_trainer.py (1 change: 1 addition & 0 deletions)
@@ -266,6 +266,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_reward_data_collator:
warnings.warn(

trl/trainer/xpo_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -377,7 +377,9 @@ def gather_mean(tensor):
self.stats["alpha"].append(self.alpha)
self.stats["beta"].append(self.beta)

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input