Add accelerate support for LongT5 models
#20341
Conversation
cc @KMFODA for inputs on tests & more 🤞
The documentation is not available anymore as the PR was closed or merged. |
younesbelkada left a comment
Very cool PR! Glad to see that 8-bit integration is gaining interest and attention on more models 🔥
Just a small typo on the Google Colab: the .cuda() is not needed after instantiating the model with load_in_8bit and device_map=auto, so I would advise removing it ;)
Can you make sure the slow tests pass with the command RUN_SLOW=1 pytest tests/models/longt5/test_modeling_longt5.py? (You will need access to a GPU instance.) When I ran your fix, the accelerate tests were failing. You can fix them by adding the lines here, as was done for BART / NLLB in #19912:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
if embed_tokens is not None:
    self.embed_tokens.weight = embed_tokens.weight
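For context, a sketch of how those lines slot into LongT5Stack.__init__ (abridged to the relevant change, mirroring the BART/NLLB fix in #19912; treat the surrounding scaffolding as illustrative):

from torch import nn
from transformers.models.longt5.modeling_longt5 import LongT5PreTrainedModel

class LongT5Stack(LongT5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)
        # Give the stack its own nn.Embedding and tie its weight to the
        # shared embedding (rather than storing a reference to the shared
        # module), so weight tying is preserved while accelerate can place
        # each submodule on its own device.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight
        # ... rest of __init__ unchanged ...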
Thanks for the feedback & good catch on the Colab! I've updated the notebook - will run the slow tests and resolve the accelerate items later today/tomorrow and report back 👌
Hey @pszemraj !
Hi @pszemraj ! Is it ok if I try to take over the PR? This addition could be very nice to the lib! Let me know what you think :)
Hey! Let me give it a stab today (I was sick for a week). If you don't see anything by tomorrow, feel free to take it home!
@younesbelkada hey - was trying to get the tests to pass and evaluate further, but unfortunately the machine I have GPU access to ran into some install issues. If you're willing to finish this, that would probably be easiest 😅 I'll add the line for accelerate as you suggested and rebase per the contrib guidelines; feel free to take whatever you find useful :)
Thanks a lot @pszemraj for your great efforts, will have a look ASAP ;) this is definitely on my TODO list
thanks so much! I see you pushed, so I will leave you to it (but feel free to let me know if you have questions or need me to change anything on my end). Then we can get this bad boi usable on free Colab runtimes :)
younesbelkada left a comment
sgugger left a comment
Thanks a lot!
Thanks for taking it home @younesbelkada! and thanks for the review @sgugger. Happy to help :)
* ✨ add accelerate support for LongT5 models
* fix `accelerate` tests
* Trigger CI test

Signed-off-by: peter szemraj <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Signed-off-by: peter szemraj [email protected]
What does this PR do?
This PR adds `accelerate` support for the LongT5 models (i.e., makes it possible to use `device_map="auto"`), so these models can be loaded in 8-bit using `load_in_8bit=True`. This helps enable inference with trained/fine-tuned SoTA long summarization models using limited memory ☺️
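As a quick illustration of what this enables, a minimal usage sketch (the checkpoint name and input text are illustrative; per the review above, no .cuda() call is needed after loading with `device_map="auto"`):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Illustrative checkpoint - substitute any LongT5 model.
model_name = "pszemraj/long-t5-tglobal-base-16384-book-summary"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",   # weights placed by accelerate; no .cuda() needed
    load_in_8bit=True,   # 8-bit weights via bitsandbytes
)

inputs = tokenizer("A very long document to summarize ...", return_tensors="pt").to(0)
summary_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))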
Took inspiration from reviewing similar PRs for other models: #19912 and #19927
cc @sgugger
test results
I made a Colab notebook that clones the branch from my fork to demo `load_in_8bit=True` working. Everything else is kept the same as the fp32/standard notebook listed on my fine-tuned model card for comparison purposes (except the function that reports the model size).

I also ran the tests for LongT5 locally:

$ python -m pytest -n auto --dist=loadfile -s -v tests/models/longt5/test_modeling_longt5.py
( ... many things here ...)
=================================================== 196 passed, 58 skipped, 118 warnings in 30.49s ===================================================