From b6ecab7dc50f64aac58edf1460019147e354b670 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 27 Aug 2024 09:14:45 +0000
Subject: [PATCH 1/2] Remove MPS

---
 trl/trainer/online_dpo_trainer.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 37fa10c14ea..594ba84f8f0 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -19,7 +19,6 @@
     is_apex_available,
     is_sagemaker_mp_enabled,
     is_torch_mlu_available,
-    is_torch_mps_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
@@ -28,7 +27,13 @@
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
 from .online_dpo_config import OnlineDPOConfig
-from .utils import DPODataCollatorWithPadding, get_reward, prepare_deepspeed, truncate_right
+from .utils import (
+    DPODataCollatorWithPadding,
+    get_reward,
+    prepare_deepspeed,
+    trl_sanitze_kwargs_for_tagging,
+    truncate_right,
+)
 
 
 if is_apex_available():
@@ -429,12 +434,8 @@ def empty_cache(self):
             torch.xpu.empty_cache()
         elif is_torch_mlu_available():
             torch.mlu.empty_cache()
-        # elif is_torch_musa_available():
-        #     torch.musa.empty_cache()
         elif is_torch_npu_available():
             torch.npu.empty_cache()
-        elif is_torch_mps_available(min_version="2.0"):
-            torch.mps.empty_cache()
         else:
             torch.cuda.empty_cache()
 

From 5dafcf868a9ef959a207e15f4b601b75f500b7a0 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 27 Aug 2024 11:04:25 +0000
Subject: [PATCH 2/2] Fix

---
 trl/trainer/online_dpo_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 594ba84f8f0..08956a1aad7 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -252,6 +252,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             top_k=0.0,
             top_p=1.0,
             do_sample=True,
+            use_cache=False if self.args.gradient_checkpointing else True,
         )
         num_examples, context_length = inputs["prompt_input_ids"].shape
         prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)