Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
is_apex_available,
is_sagemaker_mp_enabled,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xpu_available,
logging,
Expand All @@ -28,7 +27,13 @@
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .online_dpo_config import OnlineDPOConfig
from .utils import DPODataCollatorWithPadding, get_reward, prepare_deepspeed, truncate_right
from .utils import (
DPODataCollatorWithPadding,
get_reward,
prepare_deepspeed,
trl_sanitze_kwargs_for_tagging,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this because it was missing from the imports.

truncate_right,
)


if is_apex_available():
Expand Down Expand Up @@ -429,12 +434,8 @@ def empty_cache(self):
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
# elif is_torch_musa_available():
# torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
else:
torch.cuda.empty_cache()

Expand Down