Trainer multi label #7191
Conversation
        Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        A tuple with the loss, logits and labels (each being optional).
    """
    has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
No need to test for the old deprecated argument names, since they have all been changed in the lib and users can now set their own names if they are still using an old model.
good! this will reduce cruft
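A minimal, self-contained sketch (the function and argument names here are illustrative, not the PR's exact code) of what the check can reduce to once the label keys are configurable instead of hard-coded:

    def has_labels(inputs: dict, label_names=("labels",)) -> bool:
        # True only when every expected label key is present in the batch.
        return all(inputs.get(k) is not None for k in label_names)

    print(has_labels({"labels": [1, 0]}))  # True
    print(has_labels({"start_positions": [0]}, ("start_positions", "end_positions")))  # False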
Codecov Report

@@            Coverage Diff             @@
##           master    #7191      +/-   ##
==========================================
- Coverage   80.86%   79.41%   -1.46%
==========================================
  Files         169      169
  Lines       32293    32322      +29
==========================================
- Hits        26115    25668     -447
- Misses       6178     6654     +476

Continue to review full report at Codecov.
LysandreJik left a comment
LGTM, very nice addition.
def nested_xla_mesh_reduce(tensors, name):
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        # Recurse into lists/tuples so nested model outputs are reduced element-wise,
        # giving each leaf its own reduce name.
        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        # Leaf tensors are gathered across TPU cores and concatenated.
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
Did you get a chance to test this on TPU?
No, I was planning to ask you about it this morning.
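For readers without TPU access, the same recursion pattern can be exercised with plain tensors. A hedged sketch (not necessarily the PR's exact helpers) of a nested concat that merges possibly nested tuples/lists of per-batch outputs along the batch dimension:

    import torch

    def nested_concat(a, b, dim=0):
        # Recurse into lists/tuples, preserving their type; concatenate tensors at the leaves.
        if isinstance(a, (list, tuple)):
            return type(a)(nested_concat(x, y, dim=dim) for x, y in zip(a, b))
        return torch.cat((a, b), dim=dim)

    batch1 = (torch.ones(2, 3), (torch.zeros(2), torch.zeros(2)))
    batch2 = (torch.ones(4, 3), (torch.ones(4), torch.ones(4)))
    merged = nested_concat(batch1, batch2)
    print(merged[0].shape)  # torch.Size([6, 3])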
* Trainer accepts multiple labels
* Missing import
* Fix docstrings
This is a follow-up to #7126. The same kinds of models that can output multiple predictions expect multiple labels (not named "labels"), so the evaluation code needs to be changed to handle this. To support models built by users, I added a label_names field in the TrainingArguments which contains the label names. If the user does not set it, it defaults to ["labels"] for most models and ["start_positions", "end_positions"] for question answering models, so it works seamlessly for all Transformers models.

I ended up writing a few util functions that concat/numpify tensors or nested lists/tuples of tensors, to avoid testing for this everywhere in Trainer. I think the design is cleaner this way, and it also supports models with crazy outputs (for instance if we set output_attentions=True). I also added a test for the multiple-label predictions.
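A minimal usage sketch of the new field (the output directory and values are illustrative): telling the Trainer which input keys hold the labels so evaluation can gather them alongside the logits.

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="qa_output",
        # Defaults to ["labels"] for most models; question answering models need both keys.
        label_names=["start_positions", "end_positions"],
    )
    print(args.label_names)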