Conversation

@younesbelkada (Contributor) commented Dec 22, 2022

What does this PR do?

This PR mainly fixes https://github.com/huggingface/transformers/actions/runs/3754402958/jobs/6378652143

Since huggingface/accelerate#920 was merged, the fix proposed in #20760 no longer works in some specific cases when using the main branch of accelerate.

To reproduce (use the main branch of accelerate):

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
print(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype)
>>> torch.float16  # bug: #20760 is meant to keep `wo` in torch.float32

Why?

I believe this is because the aforementioned PR introduced a new dtype argument on the function set_module_tensor_to_device. If this argument is left as None (the default), the target value is automatically cast to the dtype of the old tensor, which breaks the assumptions made in #20760.
I believe the fix is to upstream this change into modeling_utils by adding support for the new argument. Since some users might not be on the latest version of accelerate, I added a small hack to keep the change backward compatible, but I am not sure if this is the best solution.
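
For illustration, a minimal sketch of the behavior (the module and values are made up for the example, and this assumes the main-branch accelerate semantics described above):

import torch
from accelerate.utils import set_module_tensor_to_device

linear = torch.nn.Linear(4, 4).half()  # existing parameter is fp16
fp32_value = torch.zeros(4, 4, dtype=torch.float32)

# With dtype=None (the default), `value` is cast to the dtype of the old
# tensor, so the fp32 upcast from #20760 is silently undone.
set_module_tensor_to_device(linear, "weight", "cpu", value=fp32_value)
print(linear.weight.dtype)
>>> torch.float16

# Passing dtype explicitly keeps the stored weight in fp32.
set_module_tensor_to_device(linear, "weight", "cpu", value=fp32_value, dtype=torch.float32)
print(linear.weight.dtype)
>>> torch.float32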

I tested this fix with both the main branch of accelerate and accelerate==0.15.0; all relevant tests pass.

cc @sgugger

@HuggingFaceDocBuilderDev commented Dec 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada added the Core: Modeling (Internals of the library; Models) label Dec 22, 2022
@sgugger (Collaborator) left a comment

Thanks for fixing! LGTM with just one nit.

force_upcast_dtype = torch.float32

# For backward compatibility with older versions of `accelerate`
if set_module_tensor_to_device.__code__.co_argcount == 5:

@sgugger (Collaborator) commented:

Slight nit: can we check the signature and parameter names using inspect? It would be clearer to read. Also add a TODO that this should become a version check at the next version of Accelerate (I will take care of it after the next release).
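
For reference, a sketch of what the suggested check could look like (the helper name is hypothetical, not the exact diff):

import inspect

import torch
from accelerate.utils import set_module_tensor_to_device

def set_tensor_fp32(module, tensor_name, device, value):
    # Hypothetical helper: keep the stored tensor in fp32 on both old and
    # new accelerate. Look the argument up by name in the signature instead
    # of counting positional arguments with __code__.co_argcount.
    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters
    # TODO: replace this with a version check at the next accelerate release.
    if accepts_dtype:
        set_module_tensor_to_device(module, tensor_name, device, value=value, dtype=torch.float32)
    else:
        # Older accelerate has no dtype argument; upcast the value ourselves.
        set_module_tensor_to_device(module, tensor_name, device, value=value.to(torch.float32))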

@younesbelkada (Author) replied:

Thanks! Should be addressed in 95486c3

- remove `force_upcast_dtype` as it is used once
- use `inspect`
- add `TODO`
@younesbelkada merged commit accad48 into huggingface:main Dec 26, 2022
MKhalusova pushed a commit to MKhalusova/transformers that referenced this pull request Dec 28, 2022
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023