Honor model dtype in load_checkpoint
#920
Conversation
The documentation is not available anymore as the PR was closed or merged.
muellerzr
left a comment
Thanks, that makes sense. I don't particularly think there is a "harm" in silently pushing it out (i.e. don't advertise the bad behavior but let it still pass) in this particular case. If we do care about phasing that out, perhaps leave it for a 1.0.0? (Similar to some optimizer bits we have.)
Actually before merging, could it maybe be better to handle this in
src/accelerate/utils/modeling.py
Outdated
                break

        if old_param is not None:
            param = param.to(old_param.dtype)
should this not be better done in `set_module_tensor_to_device`? Or maybe additionally add a `torch_dtype` arg to `set_module_tensor_to_device` that handles the param correctly if `value=param` is used?
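A minimal sketch of the cast this suggestion describes, assuming a hypothetical helper (`_cast_like_module_param` and its signature are illustrative only, not accelerate's actual `set_module_tensor_to_device` API):

```python
import torch


def _cast_like_module_param(module, tensor_name, value, torch_dtype=None):
    # Hypothetical helper: before assigning `value` to `module.<tensor_name>`,
    # cast it to an explicit `torch_dtype` if given, otherwise to the dtype the
    # module already uses for that tensor.
    old_value = getattr(module, tensor_name)
    if torch_dtype is not None:
        return value.to(torch_dtype)
    # Integer and bool tensors (e.g. quantization buffers) keep their original dtype.
    if not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        value = value.to(old_value.dtype)
    return value


layer = torch.nn.Linear(2, 2)                        # float32 parameters
new_weight = torch.zeros(2, 2, dtype=torch.float16)  # checkpoint weight in fp16
print(_cast_like_module_param(layer, "weight", new_weight).dtype)  # torch.float32
```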
    else:
        for param_name, param in checkpoint.items():
            module_name = param_name
            if dtype is not None and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
This is moved to `set_module_tensor_to_device`.
patrickvonplaten
left a comment
Thanks a lot for adapting!
- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of the dtype of the weights (huggingface/accelerate#920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
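The `torch_dtype="auto"` default mentioned in the last bullet could be resolved along these lines (a sketch only; `resolve_torch_dtype` and the config attribute access are assumptions for illustration, not petals' or accelerate's actual code):

```python
import torch


def resolve_torch_dtype(torch_dtype, config):
    # Hypothetical helper: turn a user-facing `torch_dtype` argument into a real
    # torch.dtype, reading it from the model config when "auto" is requested.
    if torch_dtype == "auto":
        # config.torch_dtype would typically come from config.json (e.g. "float16").
        torch_dtype = getattr(config, "torch_dtype", None) or torch.float32
    if isinstance(torch_dtype, str):
        torch_dtype = getattr(torch, torch_dtype)
    return torch_dtype


class DummyConfig:
    torch_dtype = "bfloat16"


print(resolve_torch_dtype("auto", DummyConfig()))  # torch.bfloat16
```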
This PR fixes a standing bug where we behave differently from PyTorch. In torch, loading a `state_dict` inside a model never changes the model's dtype. Currently in Accelerate, `load_checkpoint` does the opposite: when loading a model, it converts the model to the dtype of the state dict. This PR addresses that.

This PR only contains the fix for now; we still have to discuss whether and how to maintain backward compatibility (even if this is a bug fix), because `diffusers` might be relying on this behavior, cc @patrickvonplaten
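The PyTorch behavior referenced above can be shown with a short snippet (a sketch standing in for the example elided from the description, not the original code):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # parameters are created in float32
fp16_state_dict = {k: v.half() for k, v in model.state_dict().items()}

# load_state_dict copies values into the existing parameters, so the model
# keeps float32 even though the state dict is float16.
model.load_state_dict(fp16_state_dict)
print(model.weight.dtype)  # torch.float32
```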