diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2b6293993d..84ab86dd49 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -85,6 +85,85 @@ def generate_with_clone(*args, **kwargs): pass pass + from transformers import Trainer + from transformers.trainer_pt_utils import nested_detach + @torch.no_grad() + def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_keys,): + """ + Perform an evaluation step on `model` using `inputs`. + Subclass and override to inject custom behavior. + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + os.environ["UNSLOTH_RETURN_LOGITS"] = "1" + with torch.no_grad(): + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device) + outputs = model(**tokenized_output) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + os.environ["UNSLOTH_RETURN_LOGITS"] = "0" + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) import trl.trainer trainers = dir(trl.trainer) trainers = [x for x in trainers if x.endswith("_trainer")] @@ -95,6 +174,7 @@ def generate_with_clone(*args, **kwargs): if hasattr(current_trainer, unwrap): try: exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") except: continue + exec(f"Trainer.prediction_step=unsloth_prediction_step") pass pass