
Commit 079e45a (1 parent be9bce3)
Address changes in transformers VLM architecture (#2554)
[transformers PR #37033](huggingface/transformers#37033) re-arranges the way visual language models are built by moving the LM head from the language model to the top-level VLM (among other things). This breaks the following test:

```python
peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
model.language_model = get_peft_model(model.language_model, peft_config)
```

The reason is that all soft-prompting methods need a task type, since each task type has specific handling of the soft prompt (e.g., padding the labels according to the number of virtual tokens for causal LM). We also can't simply use `task_type='FEATURE_EXTRACTION'`, as this would not deal with `labels` either. Luckily, the VLM behaves almost like a LM (e.g., `get_input_embeddings` refers to the underlying LM), so we can target the VLM itself. The soft-prompting methods then need to detect that we're fine-tuning a VLM and take the relevant config values from `base_model.config.text_config` instead of `base_model.config` directly.
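To make the task-type requirement concrete, here is a minimal sketch of the causal-LM label padding the message alludes to (illustrative values, not code from this commit): the soft prompt prepends virtual tokens to the input, so the labels must be padded to the same length.

```python
# Minimal sketch of the CAUSAL_LM label padding alluded to above;
# shapes and values are illustrative, not taken from this commit.
import torch

num_virtual_tokens = 20
labels = torch.tensor([[42, 7, 13, 99]])  # (batch_size=1, seq_len=4)

# Prepend -100 (ignored by cross-entropy) for each virtual token so the
# labels line up with the input sequence lengthened by the soft prompt.
prefix = torch.full((labels.size(0), num_virtual_tokens), -100)
padded_labels = torch.cat((prefix, labels), dim=1)  # shape (1, 24)
```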

File tree: 4 files changed, +16 −4 lines

src/peft/mapping_func.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -109,6 +109,8 @@ def get_peft_model(
         # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
         return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
 
+    # We explicitly exclude prompt learning here since prompt learning is specific to the task and needs special
+    # handling in the PEFT model's forward method.
     if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
         return PeftModel(
             model,
```
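The guard above hinges on `is_prompt_learning`. A quick sketch of what it distinguishes, assuming a recent `peft` install (the configs below are illustrative):

```python
# is_prompt_learning separates soft-prompt configs (which need task-specific
# forward handling) from adapter configs such as LoRA (generic handling suffices).
from peft import LoraConfig, PrefixTuningConfig

lora_config = LoraConfig()
prefix_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)

assert not lora_config.is_prompt_learning    # may fall through to plain PeftModel
assert prefix_config.is_prompt_learning      # needs a task-specific PEFT model class
```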

src/peft/peft_model.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -605,9 +605,15 @@ def _setup_prompt_encoder(self, adapter_name: str):
             # For reference refer to issue: https://github.com/huggingface/peft/issues/996
             deepspeed_distributed_tensor_shape = getattr(value, "ds_shape", None)
 
-            if value.shape[0] == self.base_model.config.vocab_size or (
+            # Handle VLM case with separate text and vision configs
+            if "text_config" in self.base_model.config:
+                vocab_size = self.base_model.config.text_config.vocab_size
+            else:
+                vocab_size = self.base_model.config.vocab_size
+
+            if value.shape[0] == vocab_size or (
                 deepspeed_distributed_tensor_shape is not None
-                and deepspeed_distributed_tensor_shape[0] == self.base_model.config.vocab_size
+                and deepspeed_distributed_tensor_shape[0] == vocab_size
             ):
                 word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
                 break
```
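The same lookup, pulled out as a standalone helper for illustration; `resolve_vocab_size` is a hypothetical name, not part of PEFT:

```python
# Hypothetical helper mirroring the diff above: prefer the nested text
# config when the model is a VLM, otherwise read vocab_size directly.
def resolve_vocab_size(config):
    # `"text_config" in config` relies on PretrainedConfig's membership
    # support, exactly as the diff does.
    if "text_config" in config:
        return config.text_config.vocab_size
    return config.vocab_size
```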

src/peft/utils/other.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -891,6 +891,10 @@ def check_adapter_name(adapter_name):
 
 
 def _prepare_prompt_learning_config(peft_config, model_config):
+    # In case of VLM we focus on the language model portion of the model.
+    if "text_config" in model_config:
+        model_config = model_config["text_config"]
+
     if peft_config.num_layers is None:
         if "num_hidden_layers" in model_config:
             num_layers = model_config["num_hidden_layers"]
```
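Since the subscript access in the diff shows that `model_config` is a plain dict here, the narrowing can be illustrated without loading a model (the dict values below are made up):

```python
# Illustration of the config narrowing above; the values are made up.
vlm_config = {
    "model_type": "llava",
    "text_config": {"num_hidden_layers": 32, "vocab_size": 32064},
    "vision_config": {"num_hidden_layers": 24},
}

if "text_config" in vlm_config:
    vlm_config = vlm_config["text_config"]

# Prompt-learning defaults now come from the language-model portion.
assert vlm_config["num_hidden_layers"] == 32
```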

tests/test_vision_models.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -67,11 +67,11 @@ def test_past_kv(self):
         )
         processor = AutoProcessor.from_pretrained(model_id)
         raw_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
-        inputs = processor(prompt, raw_image, return_tensors="pt")
+        inputs = processor(text=prompt, images=raw_image, return_tensors="pt")
 
         # get peft model
         peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
-        model.language_model = get_peft_model(model.language_model, peft_config)
+        model = get_peft_model(model, peft_config)
         # check that this does not raise
         model(**inputs, output_hidden_states=True)
```
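For reference, a self-contained sketch of the updated usage pattern; the checkpoint id and prompt are illustrative and not taken from the test file:

```python
# Sketch assuming a LLaVA-style checkpoint; id and prompt are illustrative.
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration
from peft import PrefixTuningConfig, get_peft_model

model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative, not the test's model
model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "USER: <image>\nWhat is shown? ASSISTANT:"
raw_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
inputs = processor(text=prompt, images=raw_image, return_tensors="pt")

# Wrap the whole VLM rather than model.language_model.
peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
model = get_peft_model(model, peft_config)

model(**inputs, output_hidden_states=True)  # should not raise
```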
