Conversation

sgugger (Collaborator) commented Dec 5, 2022

What does this PR do?

As reported in #20390, the dtype of the weights after from_pretrained is used to load a checkpoint is inconsistent between device_map=None and a set device_map:

  • device_map=None (which uses nn.Module.load_state_dict) keeps the dtype of the model, even if the checkpoint is in a different dtype (so loading a float16 checkpoint in a float32 model gives a float32 model)
  • a set device_map (which manually sets the parameters) changes the dtype of the model to the dtype of the checkpoint (so loading a float16 checkpoint in a float32 model gives a float16 model).

This PR addresses the inconsistency so both loading paths behave the same way.
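Below is a minimal PyTorch sketch (not the actual transformers loading code) of why the two paths end up with different dtypes, and of the kind of cast that makes them consistent; the tiny model here is made up for illustration:

```python
import torch
import torch.nn as nn

# A float16 checkpoint for a tiny float32 model.
state_dict = {k: v.half() for k, v in nn.Linear(4, 4).state_dict().items()}

# Path 1: device_map=None uses nn.Module.load_state_dict, which copies the
# checkpoint values into the existing parameters and therefore keeps the
# model's dtype (float32 here).
model_a = nn.Linear(4, 4)
model_a.load_state_dict(state_dict)
print(model_a.weight.dtype)  # torch.float32

# Path 2: a set device_map assigns new parameters built from the checkpoint
# tensors, so the model takes the checkpoint's dtype (float16 here).
model_b = nn.Linear(4, 4)
for name, value in state_dict.items():
    setattr(model_b, name, nn.Parameter(value))
print(model_b.weight.dtype)  # torch.float16

# The idea of the fix (a sketch, not the PR diff): cast the checkpoint value
# to the dtype of the parameter it replaces, so path 2 matches path 1.
model_c = nn.Linear(4, 4)
for name, value in state_dict.items():
    old_param = getattr(model_c, name)
    setattr(model_c, name, nn.Parameter(value.to(old_param.dtype)))
print(model_c.weight.dtype)  # torch.float32
```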

HuggingFaceDocBuilderDev commented Dec 5, 2022

The documentation is not available anymore as the PR was closed or merged.

younesbelkada (Contributor) left a comment

Thanks a lot for fixing this and making model loading consistent between device_map=auto (or any) and device_map=None!
Just wondering if you need a special safety check for safetensors (bear in mind that I am not very knowledgeable about safetensors - what happens if old_params.dtype == torch.float16 and is_safetensors==True)?

sgugger (Collaborator, Author) commented Dec 5, 2022

There is no safetensors-specific handling at this stage: is_safetensors only means the checkpoint comes from safetensors, but the state dict is a plain dictionary of name to parameter in this case as well.
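For reference, a minimal sketch (not the transformers code) of why the format makes no difference at this point: a safetensors file is read back as an ordinary name-to-tensor dictionary, so the dtype handling above applies to both formats. The file name is made up for illustration.

```python
import torch
from safetensors.torch import save_file, load_file

# Write and reload a tiny float16 checkpoint in the safetensors format.
save_file({"weight": torch.ones(4, 4, dtype=torch.float16)}, "checkpoint.safetensors")
state_dict = load_file("checkpoint.safetensors")
print(type(state_dict), state_dict["weight"].dtype)  # <class 'dict'> torch.float16
```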

LysandreJik (Member) left a comment

LGTM
