
Enable grad checkpointing after get_peft_model #2398

Open · wants to merge 1 commit into base: main
Conversation

@SunMarc (Member) commented Feb 25, 2025

What does this PR do?

Fixes huggingface/transformers#35826
This PR enables gradient checkpointing even if the model has already been converted to a PeftModel. I can add a test if needed, but where should I put it? I see that you have _test_training_gradient_checkpointing, but it is a common test.

The following works:

model = ...
model.gradient_checkpointing_enable()
config = ...
model = get_peft_model(model, config)

but not this (with this PR, it should now work):

model = ...
config = ...
model = get_peft_model(model, config)
model.gradient_checkpointing_enable()  # doesn't raise an error or a warning, so users think it worked
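
For context, a minimal sketch of a workaround that already behaves correctly without this PR, using the same input-grads hack as prepare_model_for_kbit_training (the tiny test checkpoint is only a placeholder, and it is assumed here that the PeftModel forwards enable_input_require_grads and gradient_checkpointing_enable to the underlying transformers model):

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# placeholder checkpoint; any causal LM that supports gradient checkpointing will do
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model = get_peft_model(model, LoraConfig(task_type=TaskType.CAUSAL_LM))
model.train()

model.gradient_checkpointing_enable()
# assumed workaround: make the embedding outputs require grad so that reentrant
# checkpointing returns outputs with requires_grad=True despite the frozen base weights
model.enable_input_require_grads()

input_ids = torch.tensor([[1, 2, 3]])
loss = model(input_ids=input_ids, labels=input_ids).loss
# without enable_input_require_grads (and without this PR), the linked issue reports False here
print(loss.requires_grad)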

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) commented

Thanks for the PR @SunMarc. For my better understanding, under what circumstances would this come up?

So far, I thought we had subsumed this functionality under prepare_model_for_kbit_training:

def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
    r"""
    Note this method only works for `transformers` models.

    This method wraps the entire protocol for preparing a model before running a training. This includes:
        1- Cast the layernorm in fp32
        2- making output embedding layer require grads
        3- Add the upcasting of the lm head to fp32
        4- Freezing the base model layers to ensure they are not updated during training

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
        use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
            Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of
            `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that
            method. Note this is only available in the latest transformers versions (> 4.34.1).
    """
    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
    is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
    is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"
    is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao"
    is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False)

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {}

    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

    if (
        not is_gptq_quantized
        and not is_aqlm_quantized
        and not is_eetq_quantized
        and not is_hqq_quantized
        and not is_torchao_quantized
    ):
        # cast all non INT8 parameters to fp32
        for param in model.parameters():
            if (
                (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
            ) and param.__class__.__name__ != "Params4bit":
                param.data = param.data.to(torch.float32)

    if (
        loaded_in_kbit
        or is_gptq_quantized
        or is_aqlm_quantized
        or is_eetq_quantized
        or is_hqq_quantized
        or is_torchao_quantized
    ) and use_gradient_checkpointing:
        # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
        if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
            # For backward compatibility
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
        _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
            inspect.signature(model.gradient_checkpointing_enable).parameters
        )

        if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
            warnings.warn(
                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
                " if you want to use that feature, please upgrade to the latest version of transformers.",
                FutureWarning,
            )

        gc_enable_kwargs = (
            {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
        )

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable(**gc_enable_kwargs)

    return model

Maybe that's not a smart idea and having a separate method is preferable; I just need to understand.
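
For reference, a minimal sketch of the typical call pattern for this helper (the checkpoint name and 4-bit config are placeholders):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)

# freezes the base weights, casts remaining fp16/bf16 params to fp32,
# and enables gradient checkpointing together with the input-grads hook
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))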

@SunMarc (Member, Author) commented Feb 25, 2025

Thanks for the PR @SunMarc. For my better understanding, under what circumstances would this come up?

So far, I thought we had subsumed this functionality under prepare_model_for_kbit_training:

I've linked the issue here: huggingface/transformers#35826
Indeed, it looks like this function could potentially solve the issue, but it only works for quantized models, no? In the issue, we are dealing with a non-quantized model.
I've updated the description so that it is clearer.

@BenjaminBossan (Member) commented

Ah yes, I see, sorry for missing that. Indeed, it makes sense to enable this possibility, since users may not use prepare_model_for_kbit_training.

I checked the existing _test_training_gradient_checkpointing. My first idea was as follows: add a new argument enable_gradient_checkpointing_before_peft=True. Thus, by default, the test would do the same thing as it does now. But when setting the argument to False, gradient checkpointing would only be enabled after calling get_peft_model. From my understanding, this should result in the error.
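
A rough sketch of that idea (the skip conditions from the real test are omitted, and enable_gradient_checkpointing_before_peft is only the proposed argument, not an existing one):

def _test_training_gradient_checkpointing(
    self, model_id, config_cls, config_kwargs, enable_gradient_checkpointing_before_peft=True
):
    model = self.transformers_class.from_pretrained(model_id)

    if enable_gradient_checkpointing_before_peft:
        # current behavior of the test: enable checkpointing before wrapping
        model.gradient_checkpointing_enable()

    config = config_cls(base_model_name_or_path=model_id, **config_kwargs)
    model = get_peft_model(model, config)

    if not enable_gradient_checkpointing_before_peft:
        # the case this PR targets: enable checkpointing only after get_peft_model
        model.gradient_checkpointing_enable()

    model = model.to(self.torch_device)
    inputs = self.prepare_inputs_for_testing()
    output = model(**inputs)[0]
    loss = output.sum()
    loss.backward()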

However, when I tried this, there was no error, even when moving model.gradient_checkpointing_enable() down after get_peft_model. So I'm not quite sure how to trigger the error. Is it required to use Trainer?

@SunMarc (Member, Author) commented Feb 25, 2025

Could you try with the following script?

import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

def main():
    train_data = {"input": "input test", "output": "output test"}
    model_name = "codellama/CodeLlama-13b-Instruct-hf"
    
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    input_ids = tokenizer.encode(train_data["input"])
    output_ids = tokenizer.encode(train_data["output"])
    model_inputs_output = input_ids + output_ids + [tokenizer.eos_token_id]
    model_inputs_output = torch.tensor(model_inputs_output, dtype=torch.int64)
    labels = copy.deepcopy(model_inputs_output)
    labels[: len(input_ids)] = -1  # mask the prompt tokens so they are mapped to -100 below
    example_mask = model_inputs_output.ge(0)
    label_mask = labels.ge(0)
    model_inputs_output[~example_mask] = 0
    labels[~label_mask] = -100
    train_dataset = {
            "input_ids": model_inputs_output.unsqueeze(0).to("cuda"),
            "attention_mask": example_mask.unsqueeze(0).to("cuda"),
            "labels": labels.unsqueeze(0).to("cuda")
        }

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],  # same as llama-factory
        lora_dropout=0.05,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.train()
    model.print_trainable_parameters()
    model.to("cuda")
    model.gradient_checkpointing_enable()

    output = model(**train_dataset)
    loss = output["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()
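
A variant worth trying in the same script, based on the use_reentrant comment in prepare_model_for_kbit_training above (it is assumed here that gradient_checkpointing_kwargs gets forwarded to the underlying transformers model):

# Inside main() above, the plain call could be replaced with:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
# If the use_reentrant note in prepare_model_for_kbit_training applies here, loss.requires_grad
# should then be True even without the enable_input_require_grads hack.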

@BenjaminBossan (Member) commented

Yes, I can confirm that loss.requires_grad is False. But I don't see how this example differs from the unit test:

peft/tests/testing_common.py, lines 1214 to 1246 (commit 3dd2668):

def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
    if config_cls == PrefixTuningConfig:
        return pytest.skip(f"Test not applicable for {config_cls}")

    if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
        # TODO: no gradients on the "dense" layer, other layers work, not sure why
        self.skipTest("AdaLora with RoBERTa does not work correctly")

    if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
        # TODO: no gradients on the "dense" layer, other layers work, not sure why
        self.skipTest("OFT with Deberta does not work correctly")

    model = self.transformers_class.from_pretrained(model_id)

    if not getattr(model, "supports_gradient_checkpointing", False):
        return pytest.skip(f"Model {model_id} does not support gradient checkpointing")

    model.gradient_checkpointing_enable()

    config = config_cls(
        base_model_name_or_path=model_id,
        **config_kwargs,
    )
    model = get_peft_model(model, config)
    model = model.to(self.torch_device)

    inputs = self.prepare_inputs_for_testing()

    # check if `training` works
    output = model(**inputs)[0]

    loss = output.sum()
    loss.backward()

(as mentioned, I moved the model.gradient_checkpointing_enable() call down after get_peft_model)

I checked whether it could be the model, or whether the loss needs to be calculated by the forward call, but neither made a difference. Any idea?

@SunMarc (Member, Author) commented Feb 25, 2025

That's very strange, I don't know either :/ I'll continue investigating tomorrow then!

Successfully merging this pull request may close these issues.

model.gradient_checkpointing_enable() makes loss.requires_grad be False