diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 37fa10c14ea..594ba84f8f0 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -19,7 +19,6 @@ is_apex_available, is_sagemaker_mp_enabled, is_torch_mlu_available, - is_torch_mps_available, is_torch_npu_available, is_torch_xpu_available, logging, @@ -28,7 +27,13 @@ from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge from .online_dpo_config import OnlineDPOConfig -from .utils import DPODataCollatorWithPadding, get_reward, prepare_deepspeed, truncate_right +from .utils import ( + DPODataCollatorWithPadding, + get_reward, + prepare_deepspeed, + trl_sanitze_kwargs_for_tagging, + truncate_right, +) if is_apex_available(): @@ -429,12 +434,8 @@ def empty_cache(self): torch.xpu.empty_cache() elif is_torch_mlu_available(): torch.mlu.empty_cache() - # elif is_torch_musa_available(): - # torch.musa.empty_cache() elif is_torch_npu_available(): torch.npu.empty_cache() - elif is_torch_mps_available(min_version="2.0"): - torch.mps.empty_cache() else: torch.cuda.empty_cache()