
Conversation

@younesbelkada (Contributor)

What does this PR do?

This PR adds accelerate support to M2M100, which enables loading NLLB models in 8-bit using load_in_8bit=True.
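For anyone landing here, a minimal usage sketch of what this enables (assuming bitsandbytes and accelerate are installed; the checkpoint name below is just an illustrative NLLB checkpoint):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/nllb-200-distilled-600M"  # illustrative NLLB checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",   # let accelerate place the layers across available devices
    load_in_8bit=True,   # quantize linear layers to 8-bit via bitsandbytes
)
```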

This might contain a breaking change but I am not sure.
When initializing the model on the meta device using accelerate, the module self.shared is initialized and set to the correct device with set_tensor_to_device three times - since it is shared by 3 modules (base model, encoder, decoder) - and it somehow ends up remaining on the meta device.
Therefore, manually assigning a new module whose weights correspond to those of the shared module should do the trick. But I am wondering whether this is a breaking change, since the embedding module of the Encoder & Decoder won't be "shared" anymore. It should not be a problem at inference time, but it could be problematic when training the model.
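A rough sketch of the idea (not the exact diff in this PR; the helper name and arguments are illustrative):

```python
import torch
import torch.nn as nn

def make_untied_embedding(shared: nn.Embedding, loaded_weight: torch.Tensor) -> nn.Embedding:
    # Create a distinct nn.Embedding and assign it the weights that correspond to the
    # shared module, instead of reusing the shared instance left on the meta device.
    new_embed = nn.Embedding(
        shared.num_embeddings, shared.embedding_dim, padding_idx=shared.padding_idx
    )
    new_embed.weight = nn.Parameter(loaded_weight)
    return new_embed
```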

cc @sgugger

Also, I know T5 also supports accelerate and uses shared embeddings. The only difference I see between the two implementations is _keys_to_ignore_on_load_missing, which contains the shared weights for T5 but not for M2M100.
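For context, this is roughly what that class attribute looks like on a transformers model (illustrative only; the exact patterns T5 lists may differ):

```python
from transformers import PreTrainedModel

class MySeq2SeqModel(PreTrainedModel):
    # Keys matching these patterns are not reported as missing when loading a checkpoint,
    # which is how tied/shared embedding weights avoid spurious "missing key" warnings.
    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
        r"decoder.embed_tokens.weight",
    ]
```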

younesbelkada requested a review from sgugger on October 26, 2022 at 22:42
@HuggingFaceDocBuilderDev commented Oct 26, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

LGTM, thanks for fixing!

Comment on lines 782 to 783
if embed_pos.device != inputs_embeds.device:
    embed_pos = embed_pos.to(inputs_embeds.device)
@sgugger (Collaborator)

I don't think you need the test; it's already done inside the to method, which defaults to a no-op :-)
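For reference, a small illustration of that no-op behavior in plain PyTorch (not code from this PR):

```python
import torch

x = torch.ones(2, 2)
# Tensor.to returns the tensor itself when no device/dtype change is needed,
# so guarding the call with an explicit device comparison is redundant.
assert x.to(x.device) is x
```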

Comment on lines 1017 to 1018
if positions.device != inputs_embeds.device:
    positions = positions.to(inputs_embeds.device)
@sgugger (Collaborator)

Same here.

@younesbelkada (Contributor, Author)

perfect!
