diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index b032f228c8..65434bb095 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -123,11 +123,11 @@ def prepare_for_training_mode(f):
     @functools.wraps(f)
     def wrapper(self, *args, **kwargs):
         # Enable training mode
-        if hasattr(self, model) and hasattr(self.model, "for_training"):
+        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
             self.model.for_training()
         output = f(self, *args, **kwargs)
         # Return inference mode
-        if hasattr(self, model) and hasattr(self.model, "for_inference"):
+        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
             self.model.for_inference()
         return output
     return wrapper