Conversation

@ydshieh
Collaborator

@ydshieh ydshieh commented Nov 14, 2022

What does this PR do?

Allow the Trainer to report an evaluation loss for CLIP-like models.

Currently, this line

has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)

gives has_labels = False for CLIP-like models, so no loss value can be reported during evaluation.

Without this PR:

***** eval metrics *****
  epoch                   =        1.0
  eval_runtime            = 0:00:01.67
  eval_samples_per_second =      9.571
  eval_steps_per_second   =      4.785

With this PR:

***** eval metrics *****
  epoch                   =        1.0
  eval_loss               =     0.8159
  eval_runtime            = 0:00:01.66
  eval_samples_per_second =      9.598
  eval_steps_per_second   =      4.799
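
Conceptually, the fix adds a second path to the evaluation-loss check. A minimal sketch of the idea (the attribute name can_return_loss follows the review discussion below; this is a sketch, not the merged diff itself):

has_labels = len(self.label_names) > 0 and all(inputs.get(k) is not None for k in self.label_names)
# CLIP-like models have no label columns, but can still compute a
# (contrastive) loss when asked to return one.
loss_without_labels = len(self.label_names) == 0 and self.can_return_loss
if has_labels or loss_without_labels:
    loss = model(**inputs).loss  # surfaced as `eval_loss`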

@HuggingFaceDocBuilder

HuggingFaceDocBuilder commented Nov 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh requested a review from sgugger November 14, 2022 17:35
@ydshieh ydshieh changed the title from "Allow trainer to return loss for CLIP-like models" to "Allow trainer to return eval. loss for CLIP-like models" Nov 14, 2022
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
can_compute_loss = True if len(self.label_names) == 0 and self.can_return_loss else False
Collaborator Author
@ydshieh ydshieh Nov 14, 2022

We need to restrict this to len(self.label_names) == 0.

For models with len(self.label_names) > 0, we should check whether the inputs contain the labels the model requires - which is done one line above for has_labels.
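
As a hypothetical illustration of the two cases (the decide helper and the example inputs are invented here, not taken from the diff):

def decide(label_names, inputs, can_return_loss):
    has_labels = len(label_names) > 0 and all(inputs.get(k) is not None for k in label_names)
    can_compute_loss = len(label_names) == 0 and can_return_loss
    return has_labels, can_compute_loss

# A classification model: labels in the inputs drive the loss.
decide(["labels"], {"input_ids": [0], "labels": [1]}, False)  # -> (True, False)
# CLIP: no label columns; only the new check enables loss reporting.
decide([], {"input_ids": [0], "pixel_values": [0], "return_loss": True}, True)  # -> (False, True)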

Collaborator Author
@ydshieh ydshieh Nov 14, 2022

can_compute_loss actually means can_compute_loss_without_labels, but that is perhaps too long a name.

Collaborator

Can we rename all can_compute_loss to loss_without_labels then? It's more informative even if not completely perfect.

Collaborator
@sgugger sgugger left a comment

Very nice idea, thanks a lot! I left a few comments regarding naming, and we can make the can_return_loss function a bit better, but overall a great PR!

"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
can_compute_loss = True if len(self.label_names) == 0 and self.can_return_loss else False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename all can_compute_loss to loss_without_labels then? It's more informative even if not completely perfect.

    signature = inspect.signature(model_class.__call__)
else:
    signature = inspect.signature(model_class.forward)
return [p for p in signature.parameters if p == "return_loss"]
Collaborator

I'd prefer if this returned a bool, rather than relying on Python's implicit conversion to bool.
Also, I think we should check that the default is True, as the Trainer won't change the default.

Collaborator Author

I changed it to

any(p == "return_loss" for p in signature.parameters)

but do you mean we should have something like this (conceptually)?

any(p == "return_loss" and default_value(p) is True for p in signature.parameters)

Collaborator

Yes to the second!

"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
loss_without_labels = True if len(self.label_names) == 0 and self.can_return_loss and inputs.get("return_loss", None) is True else False
Collaborator

inputs will not contain return_loss if True is the default.

Collaborator Author

OK! I see that all of our return_loss arguments have None as their default value (including CLIP), but I can add an extra check to be sure.

Collaborator

Oh, in that case your solution works.

# If `return_loss` is not specified in `inputs` or is `None`, we check whether the default value of
# `return_loss` in `model.forward` is `True`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
Collaborator Author

If `return_loss` is False, there is no need to check the default value.
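
Putting the two pieces together, a sketch of the completed logic (the body of the `if` in the context above is truncated, so the fallback line here is an assumption based on this discussion):

return_loss = inputs.get("return_loss", None)
if return_loss is None:
    # Fall back to the signature default: `self.can_return_loss` is assumed to
    # be True only when `return_loss` defaults to True in `model.forward`.
    return_loss = self.can_return_loss
loss_without_labels = len(self.label_names) == 0 and return_loss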


for p in signature.parameters:
    if p == "return_loss" and signature.parameters[p].default is True:
        return True
Collaborator Author

Will return True only if the default value is True.
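
For reference, a self-contained sketch assembling the pieces reviewed above into one helper; the branch condition is an assumption, since the earlier diff context omits it:

import inspect

def can_return_loss(model_class):
    # Pick the entry point to inspect: `forward` for PyTorch modules,
    # `__call__` otherwise (the snippet above branches the same way).
    if hasattr(model_class, "forward"):
        signature = inspect.signature(model_class.forward)
    else:
        signature = inspect.signature(model_class.__call__)
    # True only if the model accepts `return_loss` and it defaults to True,
    # since the Trainer won't change the default.
    for p in signature.parameters:
        if p == "return_loss" and signature.parameters[p].default is True:
            return True
    return False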

@ydshieh
Collaborator Author

ydshieh commented Nov 15, 2022

@sgugger Hopefully the change covers everything that could happen now and in the future.

Collaborator
@sgugger sgugger left a comment

Perfect!

@ydshieh ydshieh merged commit 0d0d776 into main Nov 15, 2022
@ydshieh ydshieh deleted the clip_loss branch November 15, 2022 18:47
mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
Allow trainer to return eval. loss for CLIP-like models (#20214)

* Allow trainer to return loss for CLIP-like models

* Apply suggestions

* update

* update

* update

Co-authored-by: ydshieh <[email protected]>