diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 594ba84f8f..08956a1aad 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -252,6 +252,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, top_k=0.0, top_p=1.0, do_sample=True, + use_cache=False if self.args.gradient_checkpointing else True, ) num_examples, context_length = inputs["prompt_input_ids"].shape prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)