Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
CONFIG_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
can_return_loss,
find_labels,
get_full_repo_name,
is_apex_available,
Expand Down Expand Up @@ -625,6 +626,7 @@ def __init__(
self.use_tune_checkpoints = False
default_label_names = find_labels(self.model.__class__)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.can_return_loss = can_return_loss(self.model.__class__)
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

# Internal variables to keep track of the original batch size
Expand Down Expand Up @@ -3190,6 +3192,14 @@ def prediction_step(
logits and labels (each being optional).
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If False, no need to check default value

return_loss = self.can_return_loss
loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
Expand All @@ -3198,7 +3208,7 @@ def prediction_step(
ignore_keys = []

# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels:
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
Expand All @@ -3208,7 +3218,7 @@ def prediction_step(
with torch.no_grad():
if is_sagemaker_mp_enabled():
raw_outputs = smp_forward_only(model, inputs)
if has_labels:
if has_labels or loss_without_labels:
if isinstance(raw_outputs, dict):
loss_mb = raw_outputs["loss"]
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
Expand All @@ -3226,7 +3236,7 @@ def prediction_step(
logits_mb = raw_outputs
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
PaddingStrategy,
TensorType,
cached_property,
can_return_loss,
expand_dims,
find_labels,
flatten_dict,
Expand Down
22 changes: 22 additions & 0 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,28 @@ def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)


def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    A model is considered able to return a loss when its forward-like method
    exposes a `return_loss` parameter whose default value is `True` (e.g.
    CLIP-style models). TF models are inspected through `call`, Flax models
    through `__call__`, and all other (PyTorch) models through `forward`.

    Args:
        model_class (`type`): The class of the model.
    """
    class_name = model_class.__name__
    if class_name.startswith("TF"):
        forward_fn = model_class.call
    elif class_name.startswith("Flax"):
        forward_fn = model_class.__call__
    else:
        forward_fn = model_class.forward

    # Only a default of exactly `True` counts; a missing parameter or a
    # `False`/`None` default means the model does not return loss by default.
    param = inspect.signature(forward_fn).parameters.get("return_loss")
    return param is not None and param.default is True


def find_labels(model_class):
"""
Find the labels used by a given model.
Expand Down