Unable to run Chameleon model since v4.46 #34379

Description

@DarkLight1337

After updating to v4.46, vLLM's CI fails when running Chameleon through HF Transformers. Upon investigation, I found that the model cannot even be run using HF's own example script.

System Info

Python 3.9, Transformers v4.46.0

Who can help?

@laurentd-lunit since you authored #33608

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the example script in the documentation for ChameleonForConditionalGeneration:

from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
import torch
import requests
from PIL import Image

# I added `device_map="auto"` to this line so the model loads onto my GPUs; it should not affect the result
model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", device_map="auto", torch_dtype=torch.bfloat16)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)

inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)

generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

Full output:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:06<00:00,  2.10s/it]
Some kwargs in processor config are unused and will not have any effect: image_seq_length, image_token. 
/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:590: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.7` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
  warnings.warn(
/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:595: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Traceback (most recent call last):
  File "/home/cyrus/vllm/run_chameleon.py", line 15, in <module>
    generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/models/chameleon/modeling_chameleon.py", line 1589, in forward
    outputs = self.model(
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cyrus/miniconda3/envs/vllm/lib/python3.9/site-packages/transformers/models/chameleon/modeling_chameleon.py", line 1293, in forward
    raise ValueError(
ValueError: Image features and image tokens do not match: tokens: 2048, features 2

This check was introduced in #33608, so perhaps that PR is faulty.
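For context, the numbers in the error line up with each `<image>` placeholder being expanded into 1024 image-token slots while the other side of the check counts whole images. This is just my own back-of-the-envelope sketch, not the actual library code, and the 1024 value is an assumption based on the processor's `image_seq_length`:

# Rough arithmetic behind the mismatch (sketch only; assumes each <image>
# placeholder expands into image_seq_length = 1024 token slots).
num_images = 2
image_seq_length = 1024  # assumed Chameleon processor default

n_image_tokens = num_images * image_seq_length  # 2048 -> matches "tokens: 2048"
n_image_features = num_images                   # 2    -> matches "features 2"

# The check raises when the two counts differ, which would explain the
# ValueError above if the features side counts images instead of per-image
# token slots.
print(n_image_tokens, n_image_features)  # 2048 2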

Expected behavior

The example script should run successfully.
