Enable users to use their own loss functions + deal with prefetching for grad accum #34198
Conversation
ArthurZucker left a comment
LGTM, IMO a regression test on the grad norms could be fairly nice!
src/transformers/trainer.py (Outdated)
```python
self.state.num_input_tokens_seen += (
    torch.sum(
        self.accelerator.gather(
            torch.tensor(
                inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
            )
        )
    )
    .cpu()
    .item()
)
```
let's make this more readable!
Clean, did this one 🫠
you can split in 3-4 lines 🎐
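One way the chained expression above could be broken up, reusing the same names from the snippet; this is just a readability sketch, not necessarily the code that was merged:

```python
# Count tokens locally, gather the counts across processes, then sum once.
input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().cpu().item()
```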
src/transformers/trainer.py (Outdated)
```python
if (self.label_smoother is not None or self.compute_loss is not None) and "labels" in inputs:
    labels = inputs.pop("labels")
```
mmmm if people don't pass a loss, we won't use the model's default?
We will; it stays in inputs and gets passed to the model's forward().
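A simplified sketch of the control flow being discussed, using the names from the outdated snippet above; when neither a label smoother nor a user loss is set, `labels` stays in `inputs` and the model's default loss is used:

```python
# Pop labels only when we intend to compute the loss ourselves.
if (self.label_smoother is not None or self.compute_loss is not None) and "labels" in inputs:
    labels = inputs.pop("labels")
else:
    labels = None

outputs = model(**inputs)
# If labels were not popped, outputs.loss is the model's own (default) loss.
```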
src/transformers/trainer.py (Outdated)
```python
# For now we don't support object detection
try:
    num_items_in_batch = sum(
        [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
    )
```
I already quickly discussed this with Zach, so this is a more general question for the other reviewers:
Would this line work for all the different task types we support? Specifically, can we always skip the first item in the sequence, i.e. is the `[..., 1:]` part valid?
For causal autoregressive models it works, but it won't work for other ones.
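An illustrative check (not from the PR) of why the slice is valid for causal LMs: the loss is computed on shifted labels, so the first label position never has a prediction target and can be dropped when counting:

```python
import torch

# -100 marks positions that are ignored by the loss.
labels = torch.tensor([[-100, -100, 17, 42, 7],
                       [-100, 9, 13, -100, -100]])
num_items_in_batch = labels[..., 1:].ne(-100).sum().item()
print(num_items_in_batch)  # 5 label tokens actually contribute to the loss
```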
danielhanchen left a comment
Just a denominator change in the test case
ArthurZucker left a comment
Feel free to merge!
…for grad accum (huggingface#34198) * bookmark * Bookmark * Bookmark * Actually implement * Pass in kwarg explicitly * Adjust for if we do or don't have labels * Bookmark fix for od * bookmark * Fin * closer * Negate accelerate grad accum div * Fixup not training long enough * Add in compute_loss to take full model output * Document * compute_loss -> compute_loss_fn * Add a test * Refactor * Refactor * Uncomment tests * Update tests/trainer/test_trainer.py Co-authored-by: Daniel Han <danielhanchen@gmail.com> --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com>
- The basic issue is that between the version of transformers pinned by requirements.txt (4.39.1) and the current release (4.57.1), a new argument called `compute_loss_func` was added to `Trainer.__init__` (added in 4.54.1). This new argument broke things because T5Trainer in trainer.py uses positional args instead of keyword args, so all of the positional args are now effectively off by one.
- The fix was to switch from positional args to keyword args to prevent the off-by-one issue.
- This fix is backwards compatible with 4.39.1.
- This issue is also mentioned in jkallini#1.
- huggingface/transformers@6ba31a8
- huggingface/transformers#34198

```
File /opt/homebrew/Cellar/jupyterlab/4.4.5/libexec/lib/python3.13/site-packages/transformers/trainer.py:647, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, processing_class, model_init, compute_loss_func, compute_metrics, callbacks, optimizers, optimizer_cls_and_kwargs, preprocess_logits_for_metrics)
    645 self.compute_metrics = compute_metrics
    646 self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
--> 647 self.optimizer, self.lr_scheduler = optimizers
    648 self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
    649 if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:

TypeError: cannot unpack non-iterable NoneType object
```
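A minimal sketch of the kind of fix described above, assuming an ordinary Trainer construction; every variable below (model, training_args, and so on) is a placeholder rather than code from that repository:

```python
from transformers import Trainer

# Construct the Trainer with keyword arguments so that a newly inserted
# parameter such as compute_loss_func cannot silently shift the meaning of
# later arguments the way positional arguments can.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
```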

What does this PR do?
In conjunction with #34191, this PR solves the other half of what's needed:
- Letting users pass in their own loss function, used in `compute_loss`
- Prefetching `gradient_accumulation_steps` worth of data each complete step and marking how many samples were seen (`num_items_in_batch`), which can be passed to a loss function if it takes in `num_items_seen` (name TBD)

A bit of feedback is needed so we can coordinate:

- Should this count be named `num_items_in_batch` and then passed through to the loss functions as such? Or is there a better name we can think of? (A sketch of such a loss function is shown after this list.)

Fixes huggingface/trl#2175
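As mentioned above, here is a hedged sketch of what a user-supplied loss function accepting that count might look like; the exact signature and the final argument name were still being decided in this PR, so treat everything here as illustrative rather than the settled API:

```python
import torch.nn.functional as F

def my_causal_lm_loss(outputs, labels, num_items_in_batch=None):
    # Shift logits/labels the same way the causal-LM snippets above do.
    logits = outputs.logits
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Sum the per-token loss, then normalize by the number of items seen across
    # the whole gradient-accumulation window when the Trainer provides it.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

# Hypothetical usage; the argument ended up being called compute_loss_func:
# trainer = Trainer(model=model, args=args, compute_loss_func=my_causal_lm_loss, ...)
```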
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@LysandreJik @ArthurZucker