Need to explicitly set use_reentrant when calling checkpoint #26969

Closed
4 tasks
FartyPants opened this issue Oct 20, 2023 · 20 comments

@FartyPants

System Info

windows

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

According to recent PyTorch versions, you now need to set use_reentrant explicitly, because the default will change from use_reentrant=True to use_reentrant=False in the near future.

transformers.models.llama.modeling_llama
def forward...

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
            )

Expected behavior

use_reentrant should be set explicitly when calling checkpoint.
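
For reference, a minimal sketch of what passing the flag explicitly looks like when calling torch.utils.checkpoint.checkpoint directly. The `Block` module is a hypothetical stand-in for a decoder layer, not the Llama code from the issue, and `use_reentrant=False` is just one of the two valid choices.

```python
import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    """Tiny stand-in for a decoder layer (hypothetical, for illustration only)."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(hidden_states))


torch.manual_seed(0)
block = Block()
hidden_states = torch.randn(2, 8, requires_grad=True)

# Passing use_reentrant explicitly avoids relying on the default value;
# False selects the non-reentrant implementation.
output = checkpoint(block, hidden_states, use_reentrant=False)
output.sum().backward()

# Gradients should reach both the input and the layer parameters.
grads_ok = hidden_states.grad is not None and block.linear.weight.grad is not None
```

Switching to `use_reentrant=True` in the same call keeps the older behavior.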

@ArthurZucker
Collaborator

cc @fxmarty would you like to have a look at this? 😉

@ArthurZucker
Collaborator

Seems like @younesbelkada also needs this in #26917

@IbrahimAmin1
Contributor

IbrahimAmin1 commented Nov 13, 2023

You can set it explicitly in TrainingArguments via the gradient_checkpointing_kwargs argument:

training_args = TrainingArguments(
    # ... other arguments ...
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # or {"use_reentrant": True}
    # ... other arguments ...
)

@GrahamEckel

FYI, this solution does not work when using SFTTrainer() from trl as the parameter is not exposed.

@younesbelkada
Contributor

@GrahamEckel can you elaborate on the issue you face with TRL SFTTrainer? Ideally with a small reproducer 🙏

@Vectorrent

Are we able to fix this when NOT using the Trainer? I tried passing gradient_checkpointing_kwargs={'use_reentrant':False} to model.gradient_checkpointing_enabled(), but it just bombs out with a "use_reentrant is an unrecognized argument" error.

I'm currently on Transformers 4.35.2.

@younesbelkada
Contributor

@LuciferianInk which model are you using?

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

This should work for all standard transformers models. We also have CI tests for that: https://github.com/huggingface/transformers/blob/main/tests/test_modeling_common.py#L575 and

self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False})
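
As an illustration of the calling convention only, here is a minimal sketch of how a dict such as gradient_checkpointing_kwargs can be bound once and then forwarded to every checkpoint call. `fake_checkpoint` and `enable_checkpointing` are hypothetical stand-ins written for this sketch; they are not the transformers or PyTorch implementations.

```python
import functools


def fake_checkpoint(function, *args, use_reentrant=None, **kwargs):
    """Hypothetical stand-in for a checkpoint function.

    It records the use_reentrant value it received, then simply calls `function`.
    """
    fake_checkpoint.seen = use_reentrant
    return function(*args, **kwargs)


def enable_checkpointing(gradient_checkpointing_kwargs=None):
    """Bind the user-supplied kwargs once so every later call receives them."""
    bound_kwargs = gradient_checkpointing_kwargs or {}
    return functools.partial(fake_checkpoint, **bound_kwargs)


# The options travel as a single dict, mirroring the call shown above.
checkpoint_fn = enable_checkpointing({"use_reentrant": False})
result = checkpoint_fn(lambda x: x * 2, 21)
```

In this sketch the options are accepted only as one dict argument, so they have to be wrapped as `{"use_reentrant": False}` rather than passed as a bare keyword.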

@Vectorrent

Oops, syntax error. Sorry for the false alarm. With your example, I was able to fix that!

@younesbelkada
Contributor

Awesome, thanks !

@manmax31

manmax31 commented Dec 15, 2023

I am trying to fine-tune Mistral 7B using SFT and PEFT, but I get the following error when I set gradient_checkpointing=True:
ValueError: Attention mask should be of size (1, 1, 2700, 5400), but is torch.Size([1, 1, 2700, 2700])

I have tried gradient_checkpointing=True and gradient_checkpointing_kwargs={"use_reentrant": True} and I still get the above error.

These are the versions I have:
Transformers version: 4.36.1
PEFT version: 0.7.1
TRL version: 0.7.4

Here is my code:

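import torch
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
    TrainingArguments,
)
from trl import SFTTrainer

quantization_config = BitsAndBytesConfig(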
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
if torch.cuda.device_count() > 1:  # If more than 1 GPU
    model.is_parallelizable = True
    model.model_parallel = True

training_args = TrainingArguments(
    output_dir="models",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1.41e-5,
    logging_steps=1,
    num_train_epochs=1,
    # max_steps=100,
    report_to=None,
    save_steps=30,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=10,
    do_eval=True,
    greater_is_better=False,
    load_best_model_at_end=True,
    auto_find_batch_size=True,
    optim="paged_adamw_8bit",
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,  # Reduces memory use at a slight cost in speed
    gradient_checkpointing_kwargs={"use_reentrant": True},
)

# LoraConfig
peft_config = LoraConfig(
    r=32,
    lora_alpha=32, 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)

early_stop = EarlyStoppingCallback(10)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=peft_config,
    max_seq_length=2700, 
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    packing=True,
    neftune_noise_alpha=5,
    callbacks=[early_stop],
)

trainer.train()

@younesbelkada
Contributor

Hi @manmax31
The issue is fixed by #28031; please see my comment here: #28056 (comment)
Can you try with transformers main? pip install -U git+https://github.com/huggingface/transformers

@manmax31

manmax31 commented Dec 15, 2023

Thank you. Is this fix not on PyPI yet?
That's the only way our systems can access it.

@younesbelkada
Contributor

cc @ArthurZucker @amyeroberts would it make sense to do a patch release to include #28031? It fixes a regression: users were able to train as usual with PEFT and GC before the attention refactor was introduced, and #28031 restores that.

@manmax31

That would be great. I am back on 4.35.2 for now.

@amyeroberts
Collaborator

@younesbelkada If it's a regression, then yes, I think we should do a patch release (also including #28043 and #28061) cc @ArthurZucker WDYT?

@ArthurZucker
Collaborator

Yes 👍🏻


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker
Collaborator

ArthurZucker commented Jan 10, 2024

This was fixed and released, so closing.

hijkzzz pushed a commit to OpenRLHF/OpenRLHF that referenced this issue Feb 5, 2024
* fix bug: generate_args-do_sample

* fix gradient_checkpointing_kwargs bug

see: huggingface/trl#912 and huggingface/transformers#26969

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@StephennFernandes

@ArthurZucker is this issue fixed? I'm still facing the same issue even after installing a fresh release from source.

@ArthurZucker
Collaborator

Could you open a new issue, with a fresh reproducer, the output of transformers-cli env and the full traceback? 🤗
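
If it helps when putting the report together, here is a small sketch for listing the installed versions of the packages mentioned in this thread. It uses only the standard library and is a complement to, not a replacement for, the transformers-cli env output requested above; the package list is an assumption based on this thread.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version or None} for each distribution name."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


# Package names taken from the versions quoted earlier in the thread.
report = installed_versions(["torch", "transformers", "peft", "trl"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```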
