Conversation

@pszemraj
Contributor

Signed-off-by: peter szemraj [email protected]

What does this PR do?

This PR adds accelerate support for the LongT5 models (i.e., makes it possible to use device_map="auto"), so these models can be loaded in 8-bit with load_in_8bit=True.

This helps enable inference with trained/fine-tuned SoTA long-summarization models on memory-limited hardware ☺️
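
For illustration, loading a LongT5 checkpoint in 8-bit then looks roughly like the sketch below (the checkpoint name, input, and generation settings are just examples; this assumes a GPU plus accelerate and bitsandbytes are installed):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "google/long-t5-tglobal-base"  # any LongT5 checkpoint; illustrative
long_document = "..."  # your long input text goes here

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" lets accelerate dispatch layers across the available devices;
# load_in_8bit=True quantizes the weights via bitsandbytes.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
)

# The model is already placed on GPU by accelerate, so no .cuda() call is needed.
inputs = tokenizer(long_document, return_tensors="pt").to(0)
summary_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))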

I took inspiration from similar PRs for other models: #19912 and #19927.

cc @sgugger

Test results

I made a Colab notebook that clones the branch from my fork to demo load_in_8bit=True working. For comparison purposes, everything else matches the fp32/standard notebook linked on my fine-tuned model card (except the function that reports the model size).

I also ran the tests for longT5 locally:

$ python -m pytest -n auto --dist=loadfile -s -v tests/models/longt5/test_modeling_longt5.py 

( ... many things here ...)

=================================================== 196 passed, 58 skipped, 118 warnings in 30.49s ===================================================

@pszemraj
Contributor Author

cc @KMFODA for input on tests & more 🤞

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 21, 2022

The documentation is not available anymore as the PR was closed or merged.

pszemraj marked this pull request as ready for review November 21, 2022 01:59
@younesbelkada
Contributor


Very cool PR! Glad to see that 8-bit integration is gaining interest and attention on more models 🔥
Just a small typo in the Google Colab: the .cuda() is not needed after instantiating the model with load_in_8bit=True and device_map="auto", so I would advise removing it ;)

Can you make sure the slow tests pass with the command RUN_SLOW=1 pytest tests/models/longt5/test_modeling_longt5.py? (You will need access to a GPU instance.) When I ran your fix, the accelerate tests were failing; you can fix them by adding the lines below, as was done for BART/NLLB in #19912:

# Always create the embedding on this module first, then tie its weight to the
# shared embedding if one was passed in.
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
if embed_tokens is not None:
    self.embed_tokens.weight = embed_tokens.weight
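
(For context, and as I understand the BART/NLLB fix: creating the nn.Embedding on the module itself and only tying its weight to the shared embedding, rather than assigning the passed-in module directly, is what lets accelerate's device placement handle the shared parameter correctly.)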

@pszemraj
Contributor Author

Thanks for the feedback & good catch on the Colab! I've updated the notebook; I'll run and resolve the slow tests/accelerate items later today or tomorrow and report back 👌

@younesbelkada
Contributor

younesbelkada commented Nov 23, 2022

Hey @pszemraj!
How is the integration going 💪? Let me know if I can help at some point with debugging or making the tests pass ;)!

@younesbelkada
Contributor

Hi @pszemraj!
Is it ok if I take over the PR? This addition would be very nice for the library! Let me know what you think :)

@pszemraj
Contributor Author

pszemraj commented Dec 6, 2022 via email

@pszemraj
Contributor Author

pszemraj commented Dec 7, 2022

@younesbelkada hey - I was trying to get the tests to pass and evaluate further, but unfortunately the machine I have GPU access to was running into install issues with the dev dependencies for pytest etc.

If you're willing to finish this, that would probably be easiest 😅 I'll add the line for accelerate as you suggested and rebase per the contrib guidelines; feel free to take whatever you find useful :)

@younesbelkada
Contributor

Thanks a lot @pszemraj for your great efforts; I will have a look ASAP ;) This is definitely on my TODO list.

@pszemraj
Contributor Author

pszemraj commented Dec 8, 2022

Thanks so much! I see you pushed, so I'll leave you to it (but feel free to let me know if you have questions or need me to change anything on my end).

Then we can get this bad boi usable on free Colab runtimes :)

@younesbelkada
Contributor


I can confirm all slow tests pass (single & multi-GPU)!
Thanks so much @pszemraj for your great contribution and patience, and for making LongT5 models more accessible to everyone.
Leaving it now to @sgugger for a final review.

@sgugger
Collaborator


Thanks a lot!

sgugger merged commit a3345c1 into huggingface:main Dec 12, 2022
@pszemraj
Contributor Author

Thanks for taking it home @younesbelkada! And thanks for the review @sgugger. Happy to help :)

mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
* ✨ add accelerate support for LongT5 models

Signed-off-by: peter szemraj <[email protected]>

* fix `accelerate` tests

* Trigger CI test

Signed-off-by: peter szemraj <[email protected]>
Co-authored-by: younesbelkada <[email protected]>