Consider top-level buffers when computing infer_auto_device_map
#792
Conversation
The documentation is not available anymore as the PR was closed or merged.
Title changed from "buffers support when computing `infer_auto_device_map`" to "`model._buffers` support when computing `infer_auto_device_map`".
sgugger left a comment:
Can you explain how some buffers don't end up in `model._buffers`? I don't fully understand that part.
So if I understood it correctly, it happens if you have some modules such as [...]. Here is an example that I have quickly tried:
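The example from the thread is not preserved in this capture. A minimal sketch of the situation under discussion (hypothetical module names `Inner`/`Outer`, not from the PR): `model._buffers` only contains buffers registered directly on the top-level module, while `named_buffers()` recurses into submodules by default.

```python
import torch
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffer registered on a submodule: visible to named_buffers(),
        # but NOT present in the parent's _buffers dict.
        self.register_buffer("running_stat", torch.zeros(4))

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
        # Buffer registered directly on the top-level module,
        # like BART's final_logits_bias.
        self.register_buffer("final_logits_bias", torch.zeros(1, 4))

model = Outer()
# _buffers only holds buffers attached to `model` itself.
print(list(model._buffers))                   # ['final_logits_bias']
# named_buffers() recurses and also finds the submodule's buffer.
print([n for n, _ in model.named_buffers()])  # ['final_logits_bias', 'inner.running_stat']
```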
I think in this case, it's just the difference between [...]
Ah yes, I see, you're probably right here! Let me dig a bit more and get back to you.
@sgugger I might have more of a clue on what is failing. Let me know what you think! I guess this failed for [...]
Ah, in this case it looks very much like the problem #747 fixed for top-level parameters, so the fix should be pretty similar here too!
- use `model.named_buffers(recurse=False)` instead

Co-authored-by: Sylvain Gugger <[email protected]>
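A minimal sketch of the approach (not accelerate's actual implementation; `top_level_names` is a hypothetical helper): mirror what #747 did for top-level parameters by also collecting buffers registered on the root module via `named_buffers(recurse=False)`, so they get a device assignment too.

```python
import torch
import torch.nn as nn

def top_level_names(model: nn.Module):
    """Collect direct children plus parameters/buffers attached to the root module."""
    names = [name for name, _ in model.named_children()]
    # Top-level parameters (what #747 addressed).
    names += [name for name, _ in model.named_parameters(recurse=False)]
    # Top-level buffers (what this PR adds), e.g. BART's final_logits_bias.
    names += [name for name, _ in model.named_buffers(recurse=False)]
    return names

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(4, 4)
        self.register_buffer("final_logits_bias", torch.zeros(1, 4))

print(top_level_names(Toy()))  # ['lm_head', 'final_logits_bias']
```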
Title changed from "`model._buffers` support when computing `infer_auto_device_map`" to "Consider top-level buffers when computing `infer_auto_device_map`".
sgugger left a comment:
Perfect, thanks!
The whole test suite (including slow tests) is green! 🟢 Merging!
What does this PR do?
This PR adds `list(model._buffers)` inside `modules_to_treat` when computing the `auto_device_map`. This scenario occurred when I tried to add `accelerate` support for BART-like models, where `final_logits_bias` is registered as a buffer and is different from a `uint` type. It seems that we need to assign a device to this buffer.

The other solution is to "force-ignore" the buffer in `check_device_map`, since the tensors that are in `model._buffers` are stored in the state_dict.

cc @sgugger @muellerzr
Slow tests from `tests/test_bigmodeling.py` pass!
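To illustrate the state_dict point from the description (a standalone sketch with a hypothetical `BartLike` module, not the real BART implementation): a buffer registered on the root module shows up in the state_dict, so a device map must either assign it a device or explicitly ignore it.

```python
import torch
import torch.nn as nn

class BartLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(8, 8)
        # Persistent buffer on the root module, analogous to BART's final_logits_bias.
        self.register_buffer("final_logits_bias", torch.zeros(1, 8))

model = BartLike()
sd = model.state_dict()
# The buffer is serialized alongside the parameters.
print("final_logits_bias" in sd)  # True
print("lm_head.weight" in sd)     # True
```

Note that buffers registered with `persistent=False` would not appear in the state_dict, which is why only persistent top-level buffers matter here.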