
Conversation

@garrett361
Contributor

This PR adds the ability to perform padding-free SFT, which reduces memory costs and increases throughput when per_device_train_batch_size > 1. The model itself must support padding-free training to properly use this feature (Llama and Bamba models support it, for instance; see huggingface/transformers#35861 for some typical throughput improvements).

Pass --padding-free True to use it.

@vwxyzjn
Contributor

vwxyzjn commented Jun 30, 2025

Is this packing?

@garrett361
Contributor Author

Yeah, I think it's the same, @vwxyzjn.

Basically, rather than padding out the uneven examples that are batched together, you just concatenate them and add extra metadata saying where the example boundaries are.
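To make the idea concrete, here is a minimal sketch of a padding-free collator (not the PR's actual code; the field names `input_ids`, `position_ids`, and `cu_seqlens` follow common flash-attention conventions but are illustrative here):

```python
def padding_free_collate(examples):
    """Collate by concatenation instead of padding.

    Each example is a dict with an "input_ids" list. Rather than padding
    every sequence to the batch max length, concatenate them into a single
    row and record where the example boundaries are.
    """
    input_ids, position_ids, cu_seqlens = [], [], [0]
    for ex in examples:
        ids = ex["input_ids"]
        input_ids.extend(ids)
        # Positions restart at 0 for each example, so position-dependent
        # embeddings (e.g. RoPE) behave as if the examples were unpacked.
        position_ids.extend(range(len(ids)))
        # Cumulative sequence lengths let attention kernels prevent
        # examples from attending across boundaries.
        cu_seqlens.append(cu_seqlens[-1] + len(ids))
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "cu_seqlens": cu_seqlens,
    }
```

No pad tokens are ever materialized, which is where the memory and throughput savings come from.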

@garrett361
Contributor Author

Same idea as here

@garrett361
Contributor Author

FYI I'm double checking a possible issue with this at the moment; will report back after I have more info.

@garrett361 garrett361 force-pushed the padding-free-squashing-1 branch from 214853f to 447b05b Compare July 3, 2025 18:43
@garrett361 garrett361 force-pushed the padding-free-squashing-1 branch from 447b05b to ebeefe1 Compare July 3, 2025 18:44
@garrett361
Contributor Author

I'm double checking a possible issue

Okay, I was debugging some bad-looking padding-free SFT training curves, but it turned out to be an issue with the mamba kernels, which was recently fixed here. So, not an issue with this commit.

@garrett361
Contributor Author

Some plots of padding-free training. All curves use 8xA100s and have --per_device_train_batch_size 2. The only difference between any pair of curves on a plot is whether --padding_free is true or false.

First, tuluv3 training on a BambaForCausalLM model. Loss curves are nearly identical (they will never be precisely the same due to numerics), and the throughput of the --padding_free true run is ~40% higher.

[Screenshot (2025-07-03): Bamba loss and throughput curves]

And then the same experiment with LlamaForCausalLM. Similar results: loss curves are similar and --padding_free true throughput is ~40% higher.

[Screenshot (2025-07-03): Llama loss and throughput curves]

CC @vwxyzjn @hamishivi

Collaborator

@hamishivi hamishivi left a comment


Thanks! So it seems like if the concatenated sequence goes over the maximum supported length of the model, there might be issues? I guess in the ppo/grpo packing implementations, our code instead tries to fit samples into a given max length, and uses multiple microbatches if it has to.

Generally happy to merge this in since it's optional; it also might be nice to call it packing, since I think that's a more common term that people will understand.
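For contrast, the "fit samples into a given max length" strategy mentioned above can be sketched as a greedy bin-packer over example lengths. This is purely illustrative, not the repo's ppo/grpo code; a real implementation would also handle single examples longer than max_len (e.g. by splitting them across microbatches):

```python
def greedy_pack(lengths, max_len):
    """Greedily group example lengths into bins whose totals stay within max_len."""
    bins, current, total = [], [], 0
    for n in lengths:
        # Start a new bin whenever the next example would overflow this one.
        if current and total + n > max_len:
            bins.append(current)
            current, total = [], 0
        current.append(n)
        total += n
    if current:
        bins.append(current)
    return bins
```

With `greedy_pack([3, 4, 2, 5], max_len=6)` the examples land in three bins, `[[3], [4, 2], [5]]`, so no bin ever exceeds the model's max length.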


sync_each_batch: bool = False
"""Optionally sync grads every batch when using grad accumulation. Can significantly reduce memory costs."""
padding_free: bool = field(
Collaborator


minor nit: could we call this packing instead of padding_free? To make it clearer to end-users what the feature is.

Suggested change
padding_free: bool = field(
packing: bool = field(

Contributor Author


yep, that is fine

Contributor Author


done

model.gradient_checkpointing_enable()

# DataLoaders creation:
if args.padding_free:
Collaborator


Suggested change
if args.padding_free:
if args.packing:

Contributor Author


done

@garrett361
Contributor Author

garrett361 commented Jul 8, 2025

So it seems like if the concatenated sequence goes over the maximum supported length of the model, there might be issues?

Maybe? I'm not sure what the max supported length affects, apart from RoPE embeddings, or similar mechanisms like ALiBi or fixed positional embeddings. For the models I've used padding-free/packing with (Llama, Bamba), there's no problem, because the explicitly position-dependent bits adjust accordingly. Certainly this only works for some model classes, which is why the collator raises a warning.

I guess in the ppo/grpo packing implementations, our code instead tries to fit samples into a given max length, and does use multiple microbatches if it has to.

I'm only very familiar with the SFT finetune.py script at the moment, so can't yet speak to how you might use this for ppo/grpo packing.
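The "position-dependent bits adjust accordingly" point above can be illustrated with a toy rotary-embedding calculation (a hypothetical `rope_angles` helper, not the transformers implementation): the rotary angles depend only on the position ids, so restarting positions per packed example reproduces the unpacked angles exactly.

```python
def rope_angles(position_ids, dim=4, base=10000.0):
    """Toy rotary angles per (position, frequency) pair."""
    # Standard RoPE inverse-frequency schedule: base^(-2i/dim).
    freqs = [base ** (-2 * i / dim) for i in range(dim // 2)]
    # Each token's angles are fully determined by its position id.
    return [[p * f for f in freqs] for p in position_ids]

# Two examples of lengths 3 and 2 packed with restarted position_ids
# yield the same angles as the two examples processed separately.
packed = rope_angles([0, 1, 2, 0, 1])
separate = rope_angles([0, 1, 2]) + rope_angles([0, 1])
```

Here `packed == separate`, which is why the loss curves above match the padded runs so closely.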

@hamishivi hamishivi merged commit e75f1f2 into allenai:main Jul 8, 2025
3 checks passed
@fabianlim fabianlim mentioned this pull request Jul 10, 2025