diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 3230cdc207..9c032b28bd 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -16,5 +16,5 @@ from .llama import FastLlamaModel from .mistral import FastMistralModel from .qwen2 import FastQwen2Model -from .dpo import PatchDPOTrainer +from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index e7074350c3..5dc71f920a 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -14,6 +14,7 @@ __all__ = [ "PatchDPOTrainer", + "PatchKTOTrainer", ] try: @@ -127,4 +128,4 @@ def PatchDPOTrainer(): pass pass pass - +PatchKTOTrainer = PatchDPOTrainer