@gabeweisz gabeweisz commented Sep 11, 2025

This PR implements the feature request described in #1000

MaxText and MaxDiffusion (and probably others) use TransformerEngine for Flash Attention on GPUs.
When using packed inputs in THD format, TransformerEngine requires the user to specify the maximum number of segments packed into a sequence.

Grain did not previously support specifying a maximum number of segments per sequence, which would cause data corruption in TransformerEngine if the limit was exceeded.

This PR allows the user to specify the maximum number of segments per sequence in both the FirstFitPackIterDataset class and the deprecated PackAndBatchOperation class, and includes tests for both.
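
As a rough illustration of the idea (not the code in this PR, and not Grain's API), here is a minimal pure-Python sketch of first-fit packing with a cap on segments per packed sequence. The function name `first_fit_pack` and its parameters `max_length` and `max_segments` are hypothetical, chosen only for this example.

```python
from typing import List, Optional


def first_fit_pack(
    lengths: List[int],
    max_length: int,
    max_segments: Optional[int] = None,
) -> List[List[int]]:
    """Packs segment lengths into sequences using a first-fit strategy.

    Hypothetical helper for illustration only; this is not Grain's
    implementation. Each inner list holds the lengths of the segments
    packed into one sequence.
    """
    if max_segments is not None and max_segments < 1:
        raise ValueError("max_segments must be at least 1")

    packed_sequences: List[List[int]] = []
    for length in lengths:
        if length > max_length:
            raise ValueError(f"segment of length {length} exceeds max_length")
        for packed in packed_sequences:
            has_room = sum(packed) + length <= max_length
            # The extra condition: refuse a sequence that already holds
            # the maximum number of segments, even if tokens would fit.
            under_cap = max_segments is None or len(packed) < max_segments
            if has_room and under_cap:
                packed.append(length)
                break
        else:
            # No existing sequence can take this segment; start a new one.
            packed_sequences.append([length])
    return packed_sequences


lengths = [2, 2, 2, 2, 4]
uncapped = first_fit_pack(lengths, max_length=8)
capped = first_fit_pack(lengths, max_length=8, max_segments=2)
print(uncapped)  # [[2, 2, 2, 2], [4]]
print(capped)    # [[2, 2], [2, 2], [4]]
```

Without the cap, the first packed sequence ends up with four segments, which is the situation that breaks a consumer configured for fewer segments per sequence. With `max_segments=2`, the same input is spread over more sequences, trading some padding for a guaranteed bound.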

PR #1028 makes most of this PR obsolete, because it allows the user to override the packing and implement #1000 without duplicating a lot of code from Grain. This PR is still useful, though: multiple consumers of Grain need this feature, and with this change they will not each need to implement it individually.

Should #1028 be accepted, I'm happy to integrate this functionality with that one. I think that PR is great from the end-user perspective.


📚 Documentation preview 📚: https://google-grain--1039.org.readthedocs.build/
