Skip to content

[Trainer] Force is_model_parallel when model is loaded in multiple GPUs using accelerate#22532

Merged
younesbelkada merged 6 commits into
huggingface:mainfrom
younesbelkada:add-trainer-mp
Apr 3, 2023
Merged

[Trainer] Force is_model_parallel when model is loaded in multiple GPUs using accelerate#22532
younesbelkada merged 6 commits into
huggingface:mainfrom
younesbelkada:add-trainer-mp

Conversation

@younesbelkada

@younesbelkada younesbelkada commented Apr 3, 2023

Copy link
Copy Markdown
Contributor

What does this PR do?

When using the Trainer on a multi-GPU environment, users currently apply a patch that leads to some bugs.
Before running a training they need to call:

setattr(model, 'model_parallel', True)
setattr(model, 'is_parallelizable', True)

Which can lead to unexpected bugs on some models, such as T5, that has the parallelize API that is still in place, thus when forcing model_parallel to be True, calls that API, which is deprecated and should not be maintained.

Script to reproduce:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling

from peft import prepare_model_for_int8_training,LoraConfig, get_peft_model

causal_lm_model_id = "facebook/opt-350m"

model = AutoModelForCausalLM.from_pretrained(
    causal_lm_model_id,
    load_in_8bit=True,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(causal_lm_model_id)
model = prepare_model_for_int8_training(model)

# setattr(model, 'model_parallel', True)
# setattr(model, 'is_parallelizable', True)  

config = LoraConfig(
    r=16, 
    lora_alpha=32, 
    target_modules=["q_proj", "v_proj"], 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

trainer = Trainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=3,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.is_model_parallel = True

model.config.use_cache = False
trainer.train()

cc @sgugger

Related: huggingface/peft#205

@sgugger

sgugger commented Apr 3, 2023

Copy link
Copy Markdown
Collaborator

Could you elaborate why is such a patch needed and what is the goal of your PR? Cause all of this seems very hacky.

@younesbelkada

younesbelkada commented Apr 3, 2023

Copy link
Copy Markdown
Contributor Author

These hacks were needed because self.place_model_on_device needs to be set to True in order for the Trainer to work correctly on a multi-GPU environment, i.e. with a model that has been loaded across multiple GPUs (so we're talking about Naive PP here). Otherwise users will encounter device mismatch between model's input/output.

Moreover, modifying place_model_on_device directly on TrainingArguments seems to not work, as this argument seems to not be on the __init__ of that class, and also it seems to me that it is better to not touch this attribute as it is a property method:

def place_model_on_device(self):

That is why I preferred to introduce a new argument to avoid modifying what is already in place and modify directly what is needed to be edited, without having to modify the model's internals (forcing model_parallel to True on T5 models will call the deprecated parallelize API that leads to some bugs)

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 3, 2023

Copy link
Copy Markdown

The documentation is not available anymore as the PR was closed or merged.

@sgugger

sgugger commented Apr 3, 2023

Copy link
Copy Markdown
Collaborator

Or you could just analyze the device map of the model and determine if there are several GPUs used. It would be cleaner and not require the user to learn the 97th training argument.

@younesbelkada

Copy link
Copy Markdown
Contributor Author

Ahh yes good point!

Comment thread src/transformers/models/t5/modeling_t5.py Outdated
@younesbelkada younesbelkada changed the title [Trainer] Add is_model_parallel argument [Trainer] Force is_model_parallel when model is loaded in multiple GPUs using accelerate Apr 3, 2023
Comment thread src/transformers/trainer.py Outdated

if (
getattr(model, "hf_device_map", None) is not None
and len(set(model.hf_device_map.values())) > 1

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe actually check the number of GPUs, cause this could be one GPU and CPU here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense! Fixed in 5eb72b4

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multi-device placement should only be on GPUs for naive pipelining to work, right? Offloading to CPU/disk won't work, isn't it the case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think offloading to CPU/disk won't work yes, I am also unsure if CPU/disk offload training works out of the box with accelerate (without DeepSpeed)

Comment thread src/transformers/trainer.py Outdated

@sgugger sgugger left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@younesbelkada younesbelkada merged commit cab048f into huggingface:main Apr 3, 2023
@younesbelkada younesbelkada deleted the add-trainer-mp branch April 3, 2023 15:10
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
…e GPUs using `accelerate` (huggingface#22532)

* add `is_model_parallel` arg on Trainer

* add warning

* adapt from suggestions

* revert t5 changes

* remove commas

* adapt from suggestions
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…e GPUs using `accelerate` (huggingface#22532)

* add `is_model_parallel` arg on Trainer

* add warning

* adapt from suggestions

* revert t5 changes

* remove commas

* adapt from suggestions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants