diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 74b74c60b2bd..d936b5b1791a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -134,6 +134,7 @@
     CONFIG_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    can_return_loss,
     find_labels,
     get_full_repo_name,
     is_apex_available,
@@ -625,6 +626,7 @@ def __init__(
         self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+        self.can_return_loss = can_return_loss(self.model.__class__)
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

         # Internal variables to keep track of the original batch size
@@ -3190,6 +3192,14 @@ def prediction_step(
             logits and labels (each being optional).
         """
         has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        # For CLIP-like models capable of returning loss values.
+        # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of
+        # `return_loss` is `True` in `model.forward`.
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
         inputs = self._prepare_inputs(inputs)
         if ignore_keys is None:
             if hasattr(self.model, "config"):
@@ -3198,7 +3208,7 @@ def prediction_step(
                 ignore_keys = []

         # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
-        if has_labels:
+        if has_labels or loss_without_labels:
             labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
             if len(labels) == 1:
                 labels = labels[0]
@@ -3208,7 +3218,7 @@ def prediction_step(
         with torch.no_grad():
             if is_sagemaker_mp_enabled():
                 raw_outputs = smp_forward_only(model, inputs)
-                if has_labels:
+                if has_labels or loss_without_labels:
                     if isinstance(raw_outputs, dict):
                         loss_mb = raw_outputs["loss"]
                         logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
@@ -3226,7 +3236,7 @@ def prediction_step(
                         logits_mb = raw_outputs
                     logits = smp_nested_concat(logits_mb)
             else:
-                if has_labels:
+                if has_labels or loss_without_labels:
                     with self.compute_loss_context_manager():
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index f6a5b8d49499..7701145bf69a 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@
     PaddingStrategy,
     TensorType,
     cached_property,
+    can_return_loss,
     expand_dims,
     find_labels,
     flatten_dict,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index 47619d279441..1d9201b95dc9 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -336,6 +336,28 @@ def __exit__(self, *args, **kwargs):
         self.stack.__exit__(*args, **kwargs)


+def can_return_loss(model_class):
+    """
+    Check if a given model can return loss.
+
+    Args:
+        model_class (`type`): The class of the model.
+    """
+    model_name = model_class.__name__
+    if model_name.startswith("TF"):
+        signature = inspect.signature(model_class.call)
+    elif model_name.startswith("Flax"):
+        signature = inspect.signature(model_class.__call__)
+    else:
+        signature = inspect.signature(model_class.forward)
+
+    for p in signature.parameters:
+        if p == "return_loss" and signature.parameters[p].default is True:
+            return True
+
+    return False
+
+
 def find_labels(model_class):
     """
     Find the labels used by a given model.