
Conversation

@SunMarc (Member) commented Jun 5, 2023

What does this PR do?

This PR fixes two issues users can run into when using big model inference:

  • Users pass their own device map but forget that parameters tied together must be placed on the same device. We now return an error showing which parameters have to share a device (a sketch of this kind of check follows this list).
  • Users forget to tie the parameters before calling infer_auto_device_map(), which can create a bad device_map. We also return an error asking them to tie the weights before using this function.
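For the first case, the check amounts to verifying that every group of tied parameters resolves to a single device in the user-provided device_map. Below is a minimal sketch of that kind of validation; the helper name and logic are illustrative, not the PR's actual implementation:

def check_tied_params_on_same_device(tied_params, device_map):
    # `tied_params` is a list of groups of parameter names sharing the same tensor,
    # `device_map` maps module-name prefixes to devices.
    for group in tied_params:
        devices = set()
        for param_name in group:
            # Resolve the parameter to the longest matching module prefix in the device_map.
            matches = [prefix for prefix in device_map if param_name.startswith(prefix)]
            if matches:
                devices.add(device_map[max(matches, key=len)])
        if len(devices) > 1:
            raise ValueError(
                f"Tied parameters {group} are spread across devices {devices}; "
                "they need to be on the same device."
            )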

How to test it

Issue 1

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate.utils import find_tied_parameters

checkpoint = "facebook/opt-350m"

# Tied parameters (the input embedding and the LM head) all on the same device: OK.
device_map_works = {'model.decoder.embed_tokens': 'cpu',
 'model.decoder.embed_positions': 'cpu',
 'model.decoder.project_out': 'cpu',
 'model.decoder.project_in': 'cpu',
 'model.decoder.layers': 'cpu',
 'lm_head': 'cpu'}

# lm_head is tied to model.decoder.embed_tokens but placed on a different device:
# loading with this map now raises an error listing the tied parameters.
device_map_does_not_work = {'model.decoder.embed_tokens': 'cpu',
 'model.decoder.embed_positions': 'cpu',
 'model.decoder.project_out': 'cpu',
 'model.decoder.project_in': 'cpu',
 'model.decoder.layers': 'cpu',
 'lm_head': 'disk'}

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map=device_map_does_not_work,
    offload_folder="offload",
    offload_state_dict=True,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
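The find_tied_parameters import above is what exposes which weights are shared: for OPT, the LM head is tied to the input embedding, which is why placing lm_head on disk while model.decoder.embed_tokens stays on cpu is invalid. A quick, self-contained sketch of how to inspect the tied groups (the exact return format can vary across accelerate versions):

from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    meta_model = AutoModelForCausalLM.from_config(config)
meta_model.tie_weights()  # make sure shared weights are actually tied on the meta model

# Prints groups of parameter names sharing the same tensor,
# e.g. the embedding weight and the lm_head weight for OPT.
print(find_tied_parameters(meta_model))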

Issue 2

from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map

checkpoint = "facebook/opt-350m"

config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# The weights of the meta-initialized model are not tied yet,
# so this call now complains and asks to tie the weights first.
device_map = infer_auto_device_map(model, no_split_module_classes=["OPTDecoderLayer"])
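With this fix, the snippet above triggers the new check because the meta-initialized model still has untied weights. The remedy the message asks for is to tie the weights before computing the device map; a minimal sketch of the corrected flow, reusing the objects defined above:

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()  # tie the shared weights before computing the device map
device_map = infer_auto_device_map(model, no_split_module_classes=["OPTDecoderLayer"])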

@SunMarc SunMarc requested a review from sgugger June 5, 2023 15:16
@HuggingFaceDocBuilderDev commented Jun 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment


Thanks for those QOL improvements!

SunMarc and others added 2 commits June 5, 2023 12:02
Fix log

Co-authored-by: Sylvain Gugger <[email protected]>
Fix comments and tests

Fix description
has_tied_encoder_decoder = False
has_tied_module = False

if transformers.modeling_utils.PreTrainedModel in inspect.getmro(model.__class__):
Collaborator

I was thinking of testing the class __name__ to avoid the extra dependency on Transformers.
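A rough sketch of what that suggestion amounts to: compare class names in the MRO instead of the class object, so accelerate does not need to import transformers just for this check (illustrative, not necessarily the final code):

import inspect

# Check by class name rather than by identity so transformers is not required.
if "PreTrainedModel" in [cls.__name__ for cls in inspect.getmro(model.__class__)]:
    ...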

@SunMarc SunMarc merged commit b9628f1 into huggingface:main Jun 5, 2023
@SunMarc SunMarc deleted the check_tied_parameters branch June 5, 2023 19:19