Padding-free SFT #740
Conversation
Is this packing?
Yeah, I think it's the same @vwxyzjn. Basically, rather than padding out uneven examples which are batched together, you just concatenate them and add additional info to say where the example boundaries are.
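To make the idea concrete, here is a minimal collation sketch (illustrative only, not the collator this PR actually uses; the helper name padding_free_collate is made up): concatenate each batch's examples into a single row and emit position_ids that restart at every example boundary, which is the extra information marking where one example ends and the next begins.

```python
import torch

def padding_free_collate(examples):
    """Minimal padding-free collation sketch (hypothetical, for illustration only).

    Instead of padding every example up to the longest one in the batch, concatenate
    them into a single (1, total_len) row and record the example boundaries via
    position_ids that reset to 0 at the start of each example.
    """
    input_ids, labels, position_ids = [], [], []
    for ex in examples:
        input_ids.extend(ex["input_ids"])
        labels.extend(ex["labels"])
        # Positions restart per example, so position-dependent parts of the model
        # (e.g. RoPE) treat each example as its own sequence.
        position_ids.extend(range(len(ex["input_ids"])))
    return {
        "input_ids": torch.tensor([input_ids]),
        "labels": torch.tensor([labels]),
        "position_ids": torch.tensor([position_ids]),
    }

batch = [
    {"input_ids": [1, 5, 9, 2], "labels": [1, 5, 9, 2]},
    {"input_ids": [1, 7, 2], "labels": [1, 7, 2]},
]
out = padding_free_collate(batch)
# input_ids:    [[1, 5, 9, 2, 1, 7, 2]]   -- no pad tokens
# position_ids: [[0, 1, 2, 3, 0, 1, 2]]   -- positions reset at the boundary
```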
Same idea as here
FYI I'm double checking a possible issue with this at the moment; will report back after I have more info. |
Okay, I was debugging some bad-looking padding-free SFT training curves, but it turned out to be an issue with the […].

Some plots of padding-free training. All curves use 8xA100s and have […]. First, tuluv3 training on a BambaForCausalLM model. Loss curves are nearly identical (they will never be precisely the same due to numerics), and the throughput of the […].

And then the same experiment with LlamaForCausalLM. Similar results: loss curves are similar and […].
hamishivi left a comment:
Thanks! So it seems like if the concatenated sequence goes over the maximum supported length of the model, there might be issues? I guess in the ppo/grpo packing implementations, our code instead tries to fit samples into a given max length, and uses multiple microbatches if it has to.
Generally happy to merge this in since it's optional; it might also be nice to name it packing, since I think that's a more common term that people will understand.
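For contrast with the padding-free approach in this PR, here is a rough sketch of that kind of length-capped packing (illustrative only, not the actual ppo/grpo code; pack_to_max_length is a hypothetical helper): samples are greedily packed into bins of at most max_seq_length tokens, and each bin can then be processed as its own microbatch.

```python
def pack_to_max_length(samples, max_seq_length):
    """Greedy first-fit packing sketch (hypothetical, not the ppo/grpo implementation).

    Each returned bin holds whole samples whose combined length stays within
    max_seq_length; callers can then treat every bin as one packed microbatch.
    """
    bins, bin_lengths = [], []
    for sample in samples:
        if len(sample) > max_seq_length:
            sample = sample[:max_seq_length]  # or handle over-long samples elsewhere
        for i, used in enumerate(bin_lengths):
            if used + len(sample) <= max_seq_length:
                bins[i].append(sample)
                bin_lengths[i] += len(sample)
                break
        else:
            bins.append([sample])
            bin_lengths.append(len(sample))
    return bins

# Three samples packed into bins no longer than 8 tokens each:
packed = pack_to_max_length([[1] * 5, [2] * 4, [3] * 3], max_seq_length=8)
# -> [[[1]*5, [3]*3], [[2]*4]]: two bins, i.e. two packed microbatches
```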
open_instruct/finetune.py (outdated)

    sync_each_batch: bool = False
    """Optionaly sync grads every batch when using grad accumulation. Can significantly reduce memory costs."""
    padding_free: bool = field(
minor nit: could we call this packing instead of padding_free, to make it clearer to end users what the feature is?
Suggested change:
-    padding_free: bool = field(
+    packing: bool = field(
yep, that is fine
done
open_instruct/finetune.py (outdated)

    model.gradient_checkpointing_enable()

    # DataLoaders creation:
    if args.padding_free:
Suggested change:
-    if args.padding_free:
+    if args.packing:
done
Maybe? I'm not sure what the max supported length affects, apart from RoPE embeddings? Or I guess similar mechanisms like ALiBi, or just fixed positional embeddings. For the models I've used padding-free/packing with (llama, bamba), there's no problem, because the explicitly sequence-dependent bits adjust accordingly. Certainly this only works for some model classes, which is why the collator raises a warning.
I'm only very familiar with the SFT […].
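A tiny numeric illustration of the position-ids point above (hypothetical lengths, not from this PR's experiments): because positions restart at every boundary, the largest position the model ever sees is set by the longest individual example, not by the total packed length.

```python
# Illustrative only: position ids in a packed row restart at each example boundary.
example_lengths = [4096, 3000, 2048]      # hypothetical per-example lengths
position_ids = [p for n in example_lengths for p in range(n)]

total_packed_length = len(position_ids)   # 9144 tokens in one packed row
largest_position = max(position_ids)      # 4095, set by the longest single example

# A model with, say, a 4096-token positional range sees no out-of-range positions,
# even though the packed row is much longer than 4096: what matters is that each
# individual example fits and that attention stays within example boundaries.
print(total_packed_length, largest_position)  # 9144 4095
```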


This PR adds the ability to perform padding-free SFT, which reduces memory costs and increases throughput when per_device_train_batch_size > 1. The model itself must support padding-free training to properly use this feature (Llama and Bamba models support this, for instance; see huggingface/transformers#35861 for some typical throughput improvements). Pass --padding-free True to use.
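For readers who want to see the same idea in isolation, here is a hedged sketch using the upstream transformers collator (assuming your installed transformers version provides DataCollatorWithFlattening and a model/attention backend that accepts padding-free inputs; this is not the code path added by this PR):

```python
from transformers import DataCollatorWithFlattening

# Flattens a batch into one padding-free row and returns position_ids that restart
# at each example boundary (availability depends on your transformers version).
collator = DataCollatorWithFlattening()

features = [
    {"input_ids": [1, 5, 9, 2], "labels": [1, 5, 9, 2]},
    {"input_ids": [1, 7, 2], "labels": [1, 7, 2]},
]
batch = collator(features)
# batch["input_ids"] is a single (1, 7) row with the two examples concatenated and
# no padding; batch["position_ids"] resets to 0 at the start of the second example.
```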