Closed
Description
After updating to Transformers v4.46, vLLM's CI fails when running Chameleon through HF Transformers. On investigation, I found that the model cannot be run even with the example script from the HF documentation.
System Info
Python 3.9, Transformers v4.46.0
Who can help?
@laurentd-lunit, since you authored #33608
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Run the example script in the documentation for ChameleonForConditionalGeneration:
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
import torch
import requests
from PIL import Image
# I added `device_map="auto"` to this line so the model fits on my GPUs; this should not affect the result.
model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", device_map="auto", torch_dtype=torch.bfloat16)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)
inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
Full output:
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:06<00:00, 2.10s/it]
Some kwargs in processor config are unused and will not have any effect: image_seq_length, image_token.
/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:590: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.7` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
warnings.warn(
/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:595: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
warnings.warn(
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Traceback (most recent call last):
  File "/home/cyrus/vllm/run_chameleon.py", line 15, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/models/chameleon/modeling_chameleon.py", line 1589, in forward
    outputs = self.model(
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/models/chameleon/modeling_chameleon.py", line 1293, in forward
    raise ValueError(
ValueError: Image features and image tokens do not match: tokens: 2048, features 2
This check was introduced in #33608, so that PR is the likely cause.
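The numbers in the error suggest a counting mismatch rather than a wrong number of images. Assuming the processor expands each `<image>` placeholder into 1024 image tokens (2 × 1024 = 2048) and the VQ image tokenizer returns one row of 1024 codes per image, a check that counts rows instead of individual codes would report 2 features. The framework-free sketch below illustrates that assumption only; `IMAGE_TOKEN_ID`, `TOKENS_PER_IMAGE`, `build_input_ids`, and `merge_image_tokens` are hypothetical names, not the actual Transformers code. It shows that the counts agree once the features are counted element-wise.

```python
# Hypothetical constants for illustration; not taken from the Transformers source.
IMAGE_TOKEN_ID = 9999      # arbitrary id standing in for the image placeholder token
TOKENS_PER_IMAGE = 1024    # assumed number of image tokens per image (image_seq_length)


def build_input_ids(ids_before, ids_after, num_images):
    """Simulate a processor that expands each <image> into TOKENS_PER_IMAGE placeholders."""
    placeholders = [IMAGE_TOKEN_ID] * (TOKENS_PER_IMAGE * num_images)
    return ids_before + placeholders + ids_after


def merge_image_tokens(input_ids, image_tokens):
    """Replace placeholder positions with VQ codes after validating the counts.

    image_tokens is a list of per-image rows, i.e. shaped (num_images, TOKENS_PER_IMAGE).
    """
    n_placeholders = sum(1 for t in input_ids if t == IMAGE_TOKEN_ID)
    # Counting rows (len(image_tokens)) would give the number of images, which matches
    # the mismatch in this issue (tokens: 2048, features 2). Count every code instead.
    n_features = sum(len(row) for row in image_tokens)
    if n_placeholders != n_features:
        raise ValueError(
            "Image features and image tokens do not match: "
            f"tokens: {n_placeholders}, features {n_features}"
        )
    # Fill placeholders in order with the flattened per-image codes.
    codes = iter(code for row in image_tokens for code in row)
    return [next(codes) if t == IMAGE_TOKEN_ID else t for t in input_ids]


num_images = 2
input_ids = build_input_ids([1, 2, 3], [4, 5], num_images)
# Fake VQ codes: one row of TOKENS_PER_IMAGE codes per image.
image_tokens = [[100 + i] * TOKENS_PER_IMAGE for i in range(num_images)]

print(input_ids.count(IMAGE_TOKEN_ID), len(image_tokens))  # 2048 2 (rows, not tokens)
merged = merge_image_tokens(input_ids, image_tokens)
print(len(merged), merged[:4])  # 2053 [1, 2, 3, 100]
```

If this assumption holds, the fix would be to compare the placeholder count against the total number of image tokens (for a tensor, something like `numel()`) rather than its leading dimension.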
Expected behavior
The example script should run successfully.