
Fine-tuning OpenCLIP with Hugging Face's PEFT (such as LoRA) #761

Closed
KyanChen opened this issue Jul 28, 2023 · 76 comments

@KyanChen

Feature request

Fine-tuning OpenCLIP with Hugging Face's PEFT (such as LoRA)

Motivation

Fine-tuning OpenCLIP with Hugging Face's PEFT (such as LoRA)

Your contribution

refer to https://github.com/KyanChen/MakeMultiHeadNaive/tree/master for help!

@BenjaminBossan
Member

Sorry, could you please provide more details? Are you looking for help on how to achieve that, or are you suggesting that it doesn't work right now?

@KyanChen
Author

Currently, Hugging Face's PEFT (such as LoRA) cannot fine-tune the linear layers of transformer models based on torch.nn.MultiheadAttention (such as OpenCLIP). If I want to use LoRA, I have to replace the torch.nn.MultiheadAttention layer with a self-implemented naive multi-head attention layer. Can you help integrate support for this into the official PEFT library?
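For context, the workaround in the linked repo amounts to replacing nn.MultiheadAttention with an equivalent module that exposes the projections as plain nn.Linear layers, which LoRA can then target. A minimal sketch (illustrative only, assuming batch-first inputs; this is not the actual MakeMultiHeadNaive code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMultiheadAttention(nn.Module):
    # exposes q/k/v/out projections as nn.Linear so LoRA can target them directly
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        B, L, E = query.shape
        # split heads: (B, L, E) -> (B, num_heads, L, head_dim)
        q = self.q_proj(query).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).reshape(B, L, E)
        return self.out_proj(attn)

With such a replacement (after copying the weights over from the original layer), a standard LoraConfig with target_modules=["q_proj", "k_proj", "v_proj", "out_proj"] works out of the box.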

@BenjaminBossan
Member

I see, thanks for explaining. Indeed, right now it is impossible for a user to change what type of LoRA layer is being used. We have ideas about exposing a "low level" API that would allow users more fine-grained control, including the possibility of using custom layers, as you suggest. I cannot say yet whether it will really work out or when it will be ready, but I'll let you know.

@duchenzhuang

Thanks for your efforts!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@ambroser53

I'd like to bump this: being unable to put LoRA weights on anything that uses nn.MultiheadAttention is a real pain, and using a naive implementation is clunky and cumbersome. It seems strange that LoRA-Torch can do it but not PEFT.

@BenjaminBossan
Member

BenjaminBossan commented Jan 5, 2024

Hey, I created a PR to add MHA: #1324. The implementation was a bit tricky because this layer is not very "friendly" for LoRA-adaptation, but I think I got it working.

For now, this is just a rough draft, so it would be great if you could test it and tell me if it works for your use case. To install from this branch, run:

python -m pip install git+https://github.com/BenjaminBossan/peft.git@feat-add-lora-multihead-attention

So far, I did the following testing:

import torch
from torch import nn
import open_clip
from peft import LoraConfig, get_peft_model
from PIL import Image
import requests

model, preprocess = open_clip.create_model_from_pretrained('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')

# LoRA config targeting the MHA layers ("attn"); exact settings assumed for this example
config = LoraConfig(target_modules=["attn"])
peft_model = get_peft_model(model, config)
opt = torch.optim.SGD(peft_model.parameters(), 0.1)

# text encoder
text = tokenizer(["a diagram", "a dog", "a cat"])
text_features = peft_model.encode_text(text)
loss = text_features.sum()
loss.backward()
opt.step()

# image encoder
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image = preprocess(image).unsqueeze(0)
image_features = model.encode_image(image)
image_features.sum().backward()
opt.step()

@BenjaminBossan
Member

@ambroser53 I think the linked LoRA-torch library has some bugs. For instance:

import torch, loratorch
import torch.nn as nn

model_torch = loratorch.Linear(5, 6, r=4, lora_alpha=1)
loratorch.mark_only_lora_as_trainable(model_torch)
print(model_torch.state_dict().keys())
# prints odict_keys(['weight', 'bias', 'w_lora_A', 'w_lora_B'])

optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)

for _ in range(3):
    model_torch.train()
    x = torch.rand(2, 5)

    loss2 = model_torch(x).sum()
    optimizer_torch.zero_grad()
    loss2.backward()
    optimizer_torch.step()

print(model_torch.state_dict().keys())
# odict_keys(['bias', 'w_lora_A', 'w_lora_B'])
# note the missing 'weight' key!

As you can see, the weight is dropped from the state_dict, making it impossible to save the model. Same is true for named_parameters(). So if you're using this package, you should be aware of this.

@ambroser53

Hey @BenjaminBossan, cheers for the fork; I'll run some tests on Tuesday. I realised that LoRA-Torch was a bit buggy after I started trying to combine it with PEFT's LoraLayer, but if there's a way to do it without it, that'd be much better.

@BenjaminBossan
Member

@ambroser53 Did you have time to give it a try?

@ambroser53

Hi, sorry, I meant to get back to you sooner. It appears the LoRA layers are placed on the nn.MultiheadAttention blocks just fine in my model. My use case is very complicated though, as it's a custom architecture, so I will need to get back to you on how effective it is and whether the OpenCLIP fine-tuning is bottlenecked or non-performant in some way. Once I have these answers I'll report back.

@BenjaminBossan
Member

Great, thanks for testing. Do you have an ETA for when these tests finish?

Regarding performance, I would expect a larger overhead than for simple LoRA layers like Linear because of the merging-unmerging roundtrip we have to take, but I'm not sure if it makes a difference in the grand scheme of things.

@ambroser53

Should get initial results early next week if there are no disasters.

Out of curiosity, is said overhead computational or memory?

@BenjaminBossan
Member

Should get initial results early next week if there are no disasters.

Thanks!

Out of curiosity, is said overhead computational or memory?

It should be computational only. However, since we take the same approach here as LoRA-torch, it shouldn't be better or worse than using that.

@ambroser53

I've dug deeper in my testing. Mine is a very specific case where LoRA weights are only placed on certain layers and the model uses mixed quantisation, so the placement needed further tinkering. However, now that I've made sure the LoRA layers end up where they need to go, there's a logic error that seems to occur only some of the time. Essentially, say you have an nn.MultiheadAttention called attn; it will have the submodule attn.out_proj, which is an nn.Linear (or at least it should be; there's this weird NonDynamicallyQuantizableLinear wrapper going on, but let's not get into that). If your target_modules in LoraConfig point to both attn and attn.out_proj, and attn gets turned into a LoraLayer first, then when PEFT tries to find attn.out_proj it now lives under attn.base_layer.out_proj.

It doesn't look like out_proj is taken into account by merge and unmerge, which seem to deal only with in_proj_weight. The implementation of nn.MultiheadAttention never actually calls the forward of out_proj; it only passes its weight and bias tensors. I thought this could be fixed just by forcing the LoraLayer onto attn.out_proj before attn, but I think this would create problems, because nn.MultiheadAttention never calls out_proj's forward, which would then ignore the LoRA weights entirely.

Could there be a simple fix that does the same for out_proj.weight as is already done for in_proj_weight?
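A quick way to see the renaming described above (a minimal sketch, assuming a PEFT version that includes the MHA support from this PR; module names are illustrative):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=16, num_heads=2)
    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out

base = Tiny()
print([n for n, _ in base.named_modules() if n.endswith("out_proj")])
# ['attn.out_proj']

peft_model = get_peft_model(base, LoraConfig(target_modules=["attn"]))
print([n for n, _ in peft_model.named_modules() if n.endswith("out_proj")])
# roughly ['base_model.model.attn.base_layer.out_proj'] -- the original 'attn.out_proj'
# path now sits under 'attn.base_layer', so a target_modules entry written against the
# old name no longer resolves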

@BenjaminBossan
Member

Thanks a lot @ambroser53, your analysis is 100% correct. I pushed a new commit to the PR that now takes into account out_proj.

As is, we now apply LoRA to both in_proj and out_proj. There is currently no way to specify only in_proj or only out_proj. That wouldn't be easy to achieve; we would probably have to implement a new argument (or even several) on LoraConfig to allow it, which seems a bit overkill for this rather niche feature. My reasoning for applying LoRA to both, instead of only in_proj, is that the consensus recently seems to have converged towards applying LoRA to as many Linear layers as possible. LMK what you think.
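Concretely, targeting the MHA module itself is enough with this branch; a minimal config sketch (the r/alpha values are arbitrary, and "attn" depends on your model's module names):

from peft import LoraConfig

# matching the nn.MultiheadAttention module applies LoRA to both its in_proj weight and its out_proj linear
config = LoraConfig(r=8, lora_alpha=16, target_modules=["attn"])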

I'll be out of office starting next week, so that PR may stall for a while unless one of the other maintainers has time to take over. Still, please try out this new PR and give us feedback if it works for you.

@ambroser53

No, that sounds perfect; I don't think having one or the other would make sense. I should be able to give it a go now and give results next week.

@BenjaminBossan
Member

I should be able to give it a go now and give results next week.

Nice. If you can give some early feedback today, I may still have time to react to it :)

@ambroser53

This may be a problem with my own complex setup, so it could be out of scope here, but does PEFT automatically cast parameters to int8 if the underlying model is loaded in int8? I ask because part of the model is in int8 while the rest is skipped via int8_quant_skip_modules. Now that out_proj is implemented, calling get_peft_model throws an error inside _restore_weights of lora.MultiheadAttention, because registering out_proj as "weight" seems to cast it to int8 even though it was supposed to have been skipped and left as float16. Do you have any insights, or will mixed quantisation be something wholly unwieldy that I'm unlikely to find a quick fix for?

@BenjaminBossan
Member

Hmm, normally the weights should not be automatically cast to int8. If you have some way to reproduce this error, I could investigate.

Looking at this issue in general, I think, however, that this implementation will not work correctly with quantized weights. As is, we merge the LoRA weights into the base weights. When the latter are quantized, this requires special treatment, similar to the bnb layers we have for LoRA; a normal merge would surely fail. So I think we would need a completely separate MHA class for quantized layers.

I'm not exactly sure what it is that you're doing with quantization, but as you've remarked earlier, the out_proj actually uses NonDynamicallyQuantizableLinear, which from my understanding exists to prevent some kind of error with quantization. I wonder if that could be related.

@ambroser53

I understand that, but the point is that the MHA isn't quantised at all. The confusing part is that the MHA and the out_proj nn.Linear are being passed to int8_quant_skip_modules. It should be okay for now; I'll train on two cards since it can't all fit on one. Hopefully I'll have some results soon.

@BenjaminBossan
Member

I understand that, but the point is that the MHA isn't quantised at all.

Ah I see, that is indeed very strange and should not happen.

The confusing part is that the MHA and the out_proj nn.Linear are being passed to int8_quant_skip_modules

Can you point me to a reference for int8_quant_skip_modules?

@ambroser53

Here's the code for the BitsAndBytesConfig configuration object, where you can specify int8_quant_skip_modules, but there's no further documentation beyond what is in the initialisation comment. It does seem to be working, as prior to calling get_peft_model the correct modules are in the correct datatype.

I'll try to put together a code sample that reproduces this (the code I'm referring to right now is proprietary to a company).

@ambroser53

One more potential bug. It seems that when using get_peft_model on a large model with an MHA inside, it leaves the internal parameters of the MHA (i.e. in_proj_weight and out_proj.weight) with requires_grad=True. It's actually really hard to force them not to be trainable, and I don't quite know why. I wonder whether it's because of the nested LoraLayers, or something missing in terms of ensuring the base weights don't require gradients that is present in other LoraLayers.

@ambroser53

It is very bizarre. The following code is from my script. attn_pool.attn is the (only) MHA:

model.base_model.model.model.vision_model.attn_pool.attn.base_layer.in_proj_weight.requires_grad = False
model.base_model.model.model.vision_model.attn_pool.attn.base_layer.out_proj.base_layer.weight.requires_grad = False

trainable_params = [name for name, param in model.named_parameters() if param.requires_grad]

print(model.base_model.model.model.vision_model.attn_pool.attn.base_layer.in_proj_weight.requires_grad)

This outputs True, and both in_proj_weight and out_proj.weight will be in trainable_params. It's almost like iterating through the module names causes them to be made trainable. This doesn't happen with any other parameters in the wrapped model, only these two that reside in the MHA.

@ambroser53

@BenjaminBossan
Member

BenjaminBossan commented Feb 7, 2024

Hi @ambroser53 I'm back in office. Thanks a lot for figuring out this bug and providing a reproducer. I could identify the issue and it should now be fixed. When running your example locally, I now get the correct gradients. Please take a look.

It's almost like iterating through the module names causes them to be made trainable.

This was indeed the case! The reason for this is explained here:

https://github.com/huggingface/peft/pull/1324/files#diff-24a141c266b7b714ae8fcc470f31bc283f7b0f5a671bbf6d5f092741fc374104R899-R903

Here's the code for bitsandbytesconfig configuration object

Sorry, did you mean to include a link here?

@BenjaminBossan
Member

@mashijie1028

I found that LoRA does not work for in_proj_weight in attn of open_clip.

This is a consequence of how multihead attention is implemented and one of the reasons it is so complicated to apply LoRA to it. in_proj_weight is implemented as a Parameter, not a Module, which is why you can't target it directly like that.

By the way, I downloaded peft as you mentioned before:

Note that the PR you mentioned will target the whole multihead attention layer, not just one of out_proj or in_proj_weight. Take this into account when specifying the target_modules. There is no way to only target the in_proj_weight at the moment.
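For illustration, the distinction is easy to see on a plain nn.MultiheadAttention (a quick check, not PEFT-specific):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)

# in_proj_weight is a bare Parameter, so it cannot be listed in target_modules on its own
print(type(mha.in_proj_weight))  # <class 'torch.nn.parameter.Parameter'>

# out_proj is a submodule (NonDynamicallyQuantizableLinear, an nn.Linear subclass), so it can be targeted
print(type(mha.out_proj))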

@mashijie1028

@BenjaminBossan
Got it! Thanks for your reply!
I still have one question: how should I set the target_modules param in LoraConfig to target the whole multihead attention layer (as you mentioned)?

Previously, I tried target_modules=["attn"]; the code works fine when adding LoRA, but when I merge the LoRA back into the original CLIP via peft_model.merge_and_unload(), the keys 'attn.out_proj.lora_A.default.weight' and 'attn.out_proj.lora_B.default.weight' are still in peft_model.state_dict(), which means the merge operation still does not work.

Could you please provide demo code for how to set the LoraConfig and how to merge the LoRA layers back for the whole multihead attention layer?

@BenjaminBossan
Member

I can investigate this issue, but I need the code for it. Could you provide a reproducer, please? I only need the model initialization and merging, no need for the data and training part.

@mashijie1028

@BenjaminBossan
Hi! Below is my demo code:

import open_clip
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.layer import MultiheadAttention as PeftMha

lora_config = LoraConfig(
    r=16,
    target_modules=["attn"],
    lora_alpha=32,
    lora_dropout=0.05
)

model, preprocess = open_clip.create_model_from_pretrained(model_name='ViT-L-14-quickgelu', pretrained="PATH-TO-YOUR-MODEL")
tokenizer = open_clip.get_tokenizer('ViT-L-14-quickgelu')

peft_model = get_peft_model(model, lora_config)
print(len([m for m in peft_model.modules() if isinstance(m, PeftMha)]))   # 36
peft_model.print_trainable_parameters()   # trainable params: 3,244,032 || all params: 430,860,545 || trainable%: 0.7529

peft_model.merge_and_unload()
#peft_model.merge_adapter()
print(peft_model.state_dict().keys())

In my code, I use MetaCLIP via ViT-L-14-quickgelu. After LoRA merging via merge_and_unload(), I print the state_dict() of the merged model and find that some LoRA keys are still in it.

@BenjaminBossan
Member

Thanks a lot for providing the reproducer. There was indeed a bug in the code, it should now be fixed. Could you try again based on the latest commit?

Btw., this line peft_model.merge_and_unload() in your code should be changed to unloaded_model = peft_model.merge_and_unload() and then you should check unloaded_model. merge_and_unload() is not completely in-place.
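In other words, a minimal sketch of the suggested change:

# keep the return value; merge_and_unload() is not completely in-place
unloaded_model = peft_model.merge_and_unload()

# inspect the unwrapped model rather than peft_model
print([k for k in unloaded_model.state_dict().keys() if "lora" in k])  # expected: []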

@mashijie1028

mashijie1028 commented Oct 22, 2024

@BenjaminBossan
Hi! Sorry for the late reply. I re-installed PEFT from your latest commit and checked the keys in state_dict() carefully.
Unfortunately, I found that there might still be some bugs after the merge operation. Below is my demo code:

import open_clip
import requests
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from PIL import Image
from peft.tuners.lora.layer import MultiheadAttention as PeftMha


lora_config = LoraConfig(
    r=16,
    target_modules=["attn"],
    lora_alpha=32,
    lora_dropout=0.05
)

model, preprocess = open_clip.create_model_from_pretrained(model_name='ViT-L-14-quickgelu', pretrained="CLIP-PATH")

# original model
print(len(model.state_dict().keys()))   # 446
print(len(model.visual.state_dict().keys()))   # 296

# add LoRA
peft_model = get_peft_model(model, lora_config)
print(len(peft_model.state_dict().keys()))   # 590
print(len(peft_model.visual.state_dict().keys()))   # 392
print(peft_model.visual.state_dict().keys())

# merge LoRA
merged_model = peft_model.merge_and_unload()
print(len(merged_model.state_dict().keys()))   # 374
print(len(merged_model.visual.state_dict().keys()))   # 248

As the results show, there are 248 keys in the ViT of CLIP after merging, but the original number is 296. When printing the keys, I found that attn.in_proj_weight and attn.out_proj.weight are missing after merging (2 keys per layer, so 48 in total, considering the MetaCLIP ViT has 24 layers).

Could you please fix this?

@BenjaminBossan
Member

Thanks for the report. I pushed a new change to the branch that should fix it. Testing your snippet locally, I get the same values now after unloading.

@mashijie1028

Thanks for your contribution.
I checked the output features and the keys in state_dict() and found the values remain consistent before and after LoRA merging. I think the problem has been fixed now. Appreciate it!
By the way, will these commits be merged into a future PEFT release?

@BenjaminBossan
Member

Thanks for testing @mashijie1028.

By the way, will these commits be merged into a future PEFT release?

Yes, it's planned. Right now there is a blocker that prevents MHA from working with low_cpu_mem_usage=True, which I'll have to figure out how to solve. After that, it can hopefully be merged.

@mashijie1028

Cheers! Thanks again for your contribution to the community! Hope everything goes well.

@Baijiong-Lin

Baijiong-Lin commented Nov 3, 2024

@ambroser53 I think the linked LoRA-torch library has some bugs. For instance:

import torch, loratorch
import torch.nn as nn

model_torch = loratorch.Linear(5, 6, r=4, lora_alpha=1)
loratorch.mark_only_lora_as_trainable(model_torch)
print(model_torch.state_dict().keys())
# prints odict_keys(['weight', 'bias', 'w_lora_A', 'w_lora_B'])

optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)

for _ in range(3):
    model_torch.train()
    x = torch.rand(2, 5)

    loss2 = model_torch(x).sum()
    optimizer_torch.zero_grad()
    loss2.backward()
    optimizer_torch.step()

print(model_torch.state_dict().keys())
# odict_keys(['bias', 'w_lora_A', 'w_lora_B'])
# note the missing 'weight' key!

As you can see, the weight is dropped from the state_dict, making it impossible to save the model. Same is true for named_parameters(). So if you're using this package, you should be aware of this.

@BenjaminBossan @ambroser53 Hi, I have fixed this problem (Baijiong-Lin/LoRA-Torch@e3e20a0). Could you test it again?
You only need to add loratorch.register_model_param_after_backward(model_torch) after optimizer_torch.step().
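Applied to the earlier example, the suggested change would look roughly like this (a sketch based on the instruction above, assuming the fixed LoRA-Torch commit):

for _ in range(3):
    model_torch.train()
    x = torch.rand(2, 5)

    loss2 = model_torch(x).sum()
    optimizer_torch.zero_grad()
    loss2.backward()
    optimizer_torch.step()
    # re-register the base weight so it shows up again in state_dict()/named_parameters()
    loratorch.register_model_param_after_backward(model_torch)

print(model_torch.state_dict().keys())
# should now include the 'weight' key again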

@2502128021

@BenjaminBossan, hi, there are both "self attention" and "cross attention" modules in my custom model, implemented with torch.nn.MultiheadAttention. I have tested that the "self attention" modules work with your PR, but if the "cross attention" modules are included in target_modules, it throws an error:
"Only same embed for query/key/value is supported as of now for MultiheadAttention."
Note that the only difference between my "self attention" and "cross attention" modules is the inputs to q, k, v: in "self attention", q, k, and v come from the same tensor, while in "cross attention" k and v come from one tensor and q from another.
How can I solve this problem?

@BenjaminBossan
Member

@2502128021 You're getting this error because your key and value have different embedding dimensions from the query. The issue with that is that PyTorch's MultiheadAttention implementation takes a completely separate code path when this happens, which is just not supported right now (and also not easy to support). My goal is to get the MHA code for the case of same embedding dimensions merged first, then tackle the other code path later.
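To illustrate the two code paths (dimensions arbitrary):

import torch.nn as nn

# same embedding dim for q/k/v -> single fused in_proj_weight (the case the PR supports)
self_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
print(self_attn._qkv_same_embed_dim, self_attn.in_proj_weight.shape)  # True torch.Size([1536, 512])

# different key/value dims -> _qkv_same_embed_dim is False and separate projection weights are used
cross_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, kdim=256, vdim=256)
print(cross_attn._qkv_same_embed_dim)  # False
print(cross_attn.in_proj_weight)       # None; q_proj_weight/k_proj_weight/v_proj_weight exist instead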

@Baijiong-Lin

@2502128021 You can try LoRA-Torch, which supports different embedding dimensions between key/value and query in nn.MultiheadAttention.

@TonyTeng66

TonyTeng66 commented Dec 22, 2024

@BenjaminBossan Hi, thank you for your code. I was using peft 0.12.0 with your implementation and everything worked well. However, after I updated to peft 0.14.0, when I do lora_model.save_pretrained(model_path), it throws this error:
TypeError: argument of type 'Config' is not iterable.
I wonder how I can solve this problem? I am using Python 3.12.2.

@BenjaminBossan
Member

@TonyTeng66 Could you please provide a code snippet to reproduce the error? Also, pasting the full error message would be helpful.

@BenjaminBossan
Member

Update everyone: The PR #1324 has finally been merged. This means you can now install PEFT from main and apply LoRA to multihead attention layers. If you have time to test it out and report back, this would be great. This way, we can do additional fixes, if necessary, before the next PEFT release.
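For reference, installing from main typically looks like this:

python -m pip install git+https://github.com/huggingface/peft.git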

@TonyTeng66

@BenjaminBossan Hi, I tried again after installing PEFT from main. The error does not occur anymore. Thank you for the contribution!

@BenjaminBossan
Member

Great, thanks for testing. I'll close this issue then, but if anyone does encounter errors with MHA, please comment here and we can re-open it. Thanks, everyone, for your patience and your contributions.

@ztjhz

ztjhz commented Jan 16, 2025

@BenjaminBossan Hi, I tested PEFT on multihead attention and it trained successfully. However, when I load the checkpoint after training for evaluation, the following error occurs:

Traceback (most recent call last):
  File "multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "online_evaluator_worker.py", line 57, in start_worker
    worker.distribute_evaluate(agent, tasks_queue, results_queue)
  File "online_evaluator_worker.py", line 503, in distribute_evaluate
    sample_result = self.evaluate_on_task(task=task, agent=agent, worker_id=self.worker_id)
  File "online_evaluator_worker.py", line 263, in evaluate_on_task
    action, probs = agent.get_action(observations, goal)
  File "early_fusion_tsfm_models.py", line 449, in get_action
    embedded_features, _ = self.model.get_input_embedding_per_timestep(
  File "early_fusion_tsfm_models.py", line 122, in get_input_embedding_per_timestep
    visual_feats, text_feats = self.visual_encoder(
  File "torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "text_cond_visual_encoder.py", line 249, in forward
    fused_feats = self.fusion_xformer(torch.cat(input_features, 1))
  File "torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 315, in forward
    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
  File "torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 536, in forward
    self.self_attn.in_proj_weight,
  File "torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'MultiheadAttention' object has no attribute 'in_proj_weight'

I also noticed that the model weights are now saved differently:

Before:
decoder.layers.0.self_attn.in_proj_weight
decoder.layers.0.self_attn.in_proj_bias
decoder.layers.0.self_attn.out_proj.weight
decoder.layers.0.self_attn.out_proj.bias

After:
base_model.model.decoder.layers.0.self_attn.base_layer.in_proj_bias
base_model.model.decoder.layers.0.self_attn.base_layer.in_proj_weight
base_model.model.decoder.layers.0.self_attn.base_layer.out_proj.base_layer.bias
base_model.model.decoder.layers.0.self_attn.base_layer.out_proj.base_layer.weight
base_model.model.decoder.layers.0.self_attn.base_layer.out_proj.lora_A.default.weight
base_model.model.decoder.layers.0.self_attn.base_layer.out_proj.lora_B.default.weight

@BenjaminBossan
Member

@ztjhz Thanks for reporting. Would it be possible to share a reproducer for the error? Note that it is expected that the names of the parameters change when applying PEFT.

@ztjhz

ztjhz commented Jan 19, 2025

@BenjaminBossan I have been debugging for a while and finally found the issue. The error occurs when model.eval() is called. Here is a mini reproducer for the error:

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=100, nhead=4, batch_first=True),
            num_layers=2,
        )

        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=100)
        self.fc = nn.Linear(100, 1)

    def forward(self, src):
        src = self.embedding(src)
        output = self.encoder(src)
        output = self.fc(output)
        return output


net = Net()

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["self_attn", "multihead_attn", "linear1", "linear2"],
    modules_to_save=[],
)

lora_model = get_peft_model(net, peft_config)
lora_model.eval()  # the issue here

_input = torch.randint(0, 10, (5, 7))

output = lora_model(_input)

I think the issue is because of line 514 of torch/nn/modules/transformer.py, where they check whether to use the fast path:

        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first :
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim :
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

When model.eval() is called, the fast path is enabled, which causes the bug. When I manually disable the fast path, the evaluation works.
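For anyone hitting this before the fix: one way to force the slow path globally is sketched below (assuming a recent PyTorch that exposes torch.backends.mha; treat this as a workaround sketch, not part of the fix):

import torch

# assumption: torch.backends.mha is available in your PyTorch version
torch.backends.mha.set_fastpath_enabled(False)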

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Jan 20, 2025
See initial report here:
huggingface#761 (comment).

For MHA to work in all circumstances, for instance in eval mode, it requires us to expose a couple more attributes that we have missed so far. Those were added now.
@BenjaminBossan
Member

@ztjhz Thanks a lot for the reproducer. There was indeed an error caused by missing attributes; it should be fixed via #2335. If you have the opportunity, please check if that branch solves your initial issue.

BenjaminBossan added a commit that referenced this issue Jan 20, 2025
See initial report here:
#761 (comment).

For MHA to work in all circumstances, for instance in eval mode, it requires us to expose a couple more attributes that we have missed so far. Those were added now.
@ztjhz

ztjhz commented Jan 22, 2025

@BenjaminBossan Thanks! It works now!
