support Qwen2-VL with pytorch backend #2449

Merged · 10 commits · Sep 23, 2024

Conversation

irexyc
Collaborator

@irexyc irexyc commented Sep 11, 2024

Motivation

Support https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct with the pytorch backend.

Currently, it only supports image input; video input support will have to wait for the refactoring of the vision model.

#2436
#2415
#2411
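For reference, here is a minimal usage sketch with the pytorch backend (the model name comes from this PR; the config values are illustrative assumptions, not recommended settings):

from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.vl import load_image

# Qwen2-VL is served by the pytorch engine; an explicit config makes that choice visible.
pipe = pipeline('Qwen/Qwen2-VL-2B-Instruct',
                backend_config=PytorchEngineConfig(session_len=8192))

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
print(pipe(('describe this image', image)))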

@QwertyJack
Contributor

Environment: lmdeploy@9ee6abe, py311, cuda121, V100-32GB

Running the Qwen2-VL-7B model officially released by Qwen raises an error:

+ exec lmdeploy serve api_server /models/Qwen2-VL-7B-Instruct-AWQ --model-name qwen2-vl-7b --model-format awq --log-level INFO --quant-policy 8 --session-len 16384 --enable-prefix-caching --server-port 8000 --cache-max-entry-count 0.2
2024-09-12 03:45:47,426 - lmdeploy - WARNING - Fallback to pytorch engine because `/models/Qwen2-VL-7B-Instruct-AWQ` not supported by turbomind engine.
2024-09-12 03:45:49,220 - lmdeploy - INFO - matching vision model: Qwen2VLModel
2024-09-12 03:45:54,473 - lmdeploy - INFO - input backend=pytorch, backend_config=PytorchEngineConfig(tp=1, session_len=16384, max_batch_size=128, cache_max_entry_count=0.2, prefill_interval=16, block_size=64, num_cpu_blocks=0, num_gpu_blocks=0, adapters=None, max_prefill_token_num=8192, thread_safe=False, enable_prefix_caching=True, device_type='cuda', eager_mode=False, custom_module_map=None, download_dir=None, revision=None)
2024-09-12 03:45:54,473 - lmdeploy - INFO - input chat_template_config=None
2024-09-12 03:45:54,479 - lmdeploy - INFO - updated chat_template_onfig=ChatTemplateConfig(model_name='qwen', system=None, meta_instruction=None, eosys=None, user=None, eoh=None, assistant=None, eoa=None, separator=None, capability=None, stop_words=None)
2024-09-12 03:45:54,513 - lmdeploy - INFO - Checking environment for PyTorch Engine.
2024-09-12 03:45:54,689 - lmdeploy - WARNING - Engine has not been tested on triton>2.2.0.
2024-09-12 03:45:55,316 - lmdeploy - INFO - Checking model.
2024-09-12 03:45:55,317 - lmdeploy - WARNING - LMDeploy requires transformers version: [4.33.0 ~ 4.44.1], but found version: 4.45.0.dev0
2024-09-12 03:45:56,607 - lmdeploy - INFO - build model.
2024-09-12 03:45:56,980 - lmdeploy - INFO - loading weights.
2024-09-12 03:45:56,987 - lmdeploy - INFO - loading weights - "model-00001-of-00002.safetensors"
2024-09-12 03:45:57,875 - lmdeploy - INFO - loading weights - "model-00002-of-00002.safetensors"
2024-09-12 03:45:59,028 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(max_batches=128, block_size=64, num_cpu_blocks=1170, num_gpu_blocks=80, window_size=-1, cache_max_entry_count=0.2, max_prefill_token_num=8192, enable_prefix_caching=True)
2024-09-12 03:46:03,443 - lmdeploy - INFO - updated backend_config=PytorchEngineConfig(tp=1, session_len=16384, max_batch_size=128, cache_max_entry_count=0.2, prefill_interval=16, block_size=64, num_cpu_blocks=0, num_gpu_blocks=0, adapters=None, max_prefill_token_num=8192, thread_safe=False, enable_prefix_caching=True, device_type='cuda', eager_mode=False, custom_module_map=None, download_dir=None, revision=None)
HINT:    Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
INFO:     Started server process [25717]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     127.0.0.1:55312 - "POST /v1/chat/completions HTTP/1.1" 200 OK
2024-09-12 03:46:47,224 - lmdeploy - INFO - start ImageEncoder._forward_loop
2024-09-12 03:46:47,224 - lmdeploy - INFO - ImageEncoder received 1 images, left 1 images.
2024-09-12 03:46:47,224 - lmdeploy - INFO - ImageEncoder process 1 images, left 0 images.
2024-09-12 03:46:47,843 - lmdeploy - INFO - ImageEncoder forward 1 images, cost 0.619s
2024-09-12 03:46:47,844 - lmdeploy - INFO - ImageEncoder done 1 images, left 0 images.
2024-09-12 03:46:47,847 - lmdeploy - INFO - prompt='<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><IMAGE_TOKEN><|vision_end|>What is this? Explain it in Chinese.<|im_end|>\n<|im_start|>assistant\n', gen_config=GenerationConfig(n=1, max_new_tokens=40000, do_sample=True, top_p=1.0, top_k=40, min_p=0.0, temperature=0.3, repetition_penalty=1.0, ignore_eos=False, random_seed=14616964279972399396, stop_words=None, bad_words=None, stop_token_ids=[151645], bad_token_ids=None, min_new_tokens=None, skip_special_tokens=True, logprobs=None, response_format=None, logits_processors=None), prompt_token_id=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 0, 0, 0, ... (long run of 0 placeholder ids for the image elided) ..., 0, 151653, 3838, 374, 419, 30, 81917, 432, 304, 8453, 13, 151645, 198, 151644, 77091, 198], adapter_name=None.
2024-09-12 03:46:47,847 - lmdeploy - INFO - session_id=1, history_tokens=0, input_tokens=316, max_new_tokens=40000, seq_start=True, seq_end=True, step=0, prep=True
2024-09-12 03:46:47,847 - lmdeploy - ERROR - Truncate max_new_tokens to 16068
2024-09-12 03:46:47,859 - lmdeploy - ERROR - Engine loop failed with error: 'int' object has no attribute 'to'
Traceback (most recent call last):
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 957, in async_loop
    await self._async_loop()
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 947, in _async_loop
    await __step(True)
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 933, in __step
    raise e
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 925, in __step
    raise out
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 869, in _async_loop_background
    await self._async_step_background(
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 748, in _async_step_background
    output = await self._async_model_forward(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/utils.py", line 237, in __tmp
    return (await func(*args, **kwargs))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 646, in _async_model_forward
    ret = await __forward(inputs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/engine.py", line 624, in __forward
    return await self.model_agent.async_forward(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 332, in async_forward
    output = self._forward_impl(inputs,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 299, in _forward_impl
    output = model_forward(
             ^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 140, in model_forward
    inputs = inputs.to_device('cuda')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/model_inputs.py", line 227, in to_device
    v = v.to_device(device)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/model_inputs.py", line 83, in to_device
    v = [x.to(device) for x in v]
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/model_inputs.py", line 83, in <listcomp>
    v = [x.to(device) for x in v]
         ^^^^
AttributeError: 'int' object has no attribute 'to'
ERROR:    Traceback (most recent call last):
  File "/home/ma/lmd/lib/python3.11/asyncio/tasks.py", line 500, in wait_for
    return fut.result()
           ^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/asyncio/queues.py", line 158, in get
    await getter
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ma/lmd/lib/python3.11/site-packages/lmdeploy/pytorch/engine/request.py", line 171, in __no_threadsafe_get
    return await asyncio.wait_for(self.resp_que.get(), timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/asyncio/tasks.py", line 502, in wait_for
    raise exceptions.TimeoutError() from exc
TimeoutError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ma/lmd/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ma/lmd/lib/python3.11/asyncio/base_events.py", line 640, in run_until_complete
...
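For anyone hitting the same trace: the failure reduces to calling .to(device) on plain Python ints where tensors are expected. A standalone illustration (an assumption-based sketch, not lmdeploy code):

import torch

values = [151652, 0, 0, 151653]            # plain ints, e.g. placeholder token ids
try:
    values = [x.to('cuda') for x in values]
except AttributeError as err:
    print(err)                             # 'int' object has no attribute 'to'

# The same pattern works once the elements are tensors:
values = [torch.tensor(v) for v in [151652, 0, 0, 151653]]
values = [x.to('cpu') for x in values]     # 'cpu' so the snippet runs without a GPU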

@irexyc
Collaborator Author

irexyc commented Sep 12, 2024

@QwertyJack

Currently, the pytorch backend does not support loading AWQ models; that support is still being worked on.

Only Qwen2-VL-2B-Instruct is supported for now.

@QwertyJack
Contributor

Okay, thanks for your effort!

@lvhan028 lvhan028 added the enhancement New feature or request label Sep 14, 2024
@@ -128,6 +133,14 @@ def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
            self.input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(
                1, max_num_tokens, emb_size)
            self.input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds
        if mrope_position_ids is not None:

Collaborator

Does mrope_position_ids always exist?

Collaborator Author

For the qwen2-vl model, it will always exist.

Collaborator

For the qwen2-vl model, it will always exist.

If there is no image input, it can be None. It just raised an error when I used the lmdeploy chat command. However, the pipeline worked even when there was no image input.
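As a side note, a minimal sketch of the guard being discussed (the names follow the diff above, but the body is an assumption rather than the actual patch):

import torch

def fill_mrope_buffer(input_buffers: dict, mrope_position_ids, max_num_tokens: int):
    # Copy multimodal rotary position ids into a persistent buffer.
    # Text-only requests may pass None, in which case nothing is filled.
    if mrope_position_ids is None:
        return
    num_tokens = mrope_position_ids.size(-1)
    if 'mrope_position_ids' not in input_buffers:
        # Qwen2-VL uses three rotary axes (temporal, height, width).
        input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_num_tokens)
    input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids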

@ldknight
Hi, thanks for the exciting work, but I encountered this problem while using it: 'Qwen2VLForConditionalGeneration' object has no attribute 'lm_head'

Collaborator

@AllentDan AllentDan left a comment

Got the following error:

cannot import name 'Qwen2VLForConditionalGeneration' from 'transformers'

However, I am using the version of transformers required in config.json.

@chenzhengda

Are there any plans regarding the Turbomind Engine?

@irexyc
Collaborator Author

irexyc commented Sep 18, 2024

@ldknight @AllentDan
please install transformers by pip install git+https://github.com/huggingface/transformers.git

@chenzhengda
Yes, we have this plan and will support it later

@ldknight

@ldknight @AllentDan please install transformers by pip install git+https://github.com/huggingface/transformers.git

@chenzhengda Yes, we have this plan and will support it later

Thanks for your reply. I have tried to reinstall transformers, but the problem still occurs. Currently I use torch==2.4.0, transformers==4.45.0.dev0, accelerate==0.34.0, and qwen-vl-utils==0.0.4.

@PiyushSawarkar

PiyushSawarkar commented Sep 19, 2024

@ldknight @AllentDan please install transformers by pip install git+https://github.com/huggingface/transformers.git
@chenzhengda Yes, we have this plan and will support it later

Thanks for your reply. I have tried to reinstall transformers, but the problem still occurs. Currently I use torch==2.4.0, transformers==4.45.0.dev0, accelerate==0.34.0, and qwen-vl-utils==0.0.4.

Hi @ldknight, I was able to run inference with the Qwen/Qwen2-VL-7B-Instruct model using @irexyc's git repo, so I would like to share the steps I followed:

  • Make sure you clone the qwen2-vl branch of the above repo.
  • After cloning,
    cd lmdeploy
  • Build the docker image:
    docker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile
    This creates the image openmmlab/lmdeploy:qwen2vl (~15.8 GB in size).
  • Now launch the docker container
    docker run --gpus all --net host --shm-size 16g -v $(pwd):/opt/lmdeploy -it --name QWEN2VL-CONT openmmlab/lmdeploy:qwen2vl /bin/bash
  • Once inside the container, run the following commands:
cd /opt/lmdeploy
mkdir -p build && cd build
bash ../generate.sh make
make -j$(nproc) && make install
cd ..
pip install -e .
  • Test using this inference code:
from lmdeploy import pipeline
from lmdeploy.vl import load_image

pipe = pipeline('Qwen/Qwen2-VL-7B-Instruct')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
print(response)

@grimoire
Collaborator

chatglm3 failed with transformers=4.45.0.dev0

  File "xxx/transformers/src/transformers/tokenization_utils_base.py", line 3509, in pad
    encoded_inputs = self._pad(
TypeError: _pad() got an unexpected keyword argument 'padding_side'

@AllentDan
Collaborator

AllentDan commented Sep 20, 2024

lmdeploy chat Qwen/Qwen2-VL-2B-Instruct failed. Besides, the vl pipeline failed with the error 'Qwen2VLForConditionalGeneration' object has no attribute 'lm_head'. The transformers (4.45.0.dev0) and accelerate packages are already the latest versions.

Collaborator

@AllentDan AllentDan left a comment

LGTM

Collaborator

@AllentDan AllentDan left a comment

As for the transformers version, shall we add a restriction inside the qwen2-vl code for users?
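If such a restriction is added, a hedged sketch of what the check could look like (a hypothetical helper, not the merged implementation):

import transformers
from packaging import version

MIN_TRANSFORMERS = '4.45.0'  # assumed minimum exposing Qwen2VLForConditionalGeneration

def check_transformers_for_qwen2_vl():
    installed = transformers.__version__
    if version.parse(installed) < version.parse(MIN_TRANSFORMERS):
        raise ImportError(
            f'Qwen2-VL requires transformers >= {MIN_TRANSFORMERS}, but found {installed}. '
            'Try: pip install git+https://github.com/huggingface/transformers.git')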

@lvhan028
Collaborator

We may also need to update the README and support_models.md.

@lvhan028 lvhan028 merged commit 254d90a into InternLM:main Sep 23, 2024
5 checks passed
@PedroRASB

Hi, I see that the list of supported models does not include Qwen2-VL-72B (https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct). Is it being added too?

@QwertyJack
Contributor

Hi, I see that the list of supported models does not include Qwen2-VL-72B (https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct). Is it being added too?

+1

@irexyc
Collaborator Author

irexyc commented Sep 26, 2024

@PedroRASB @QwertyJack

I compared the config of Qwen2-VL-72B with Qwen2-VL-7B; only the input/output dimensions of some layers differ. So I think the current code should support the 72B model. I'll check it, but it will take some time to download the model.
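A quick way to reproduce that comparison (a sketch; it only downloads the config files from huggingface.co):

from transformers import AutoConfig

cfg_7b = AutoConfig.from_pretrained('Qwen/Qwen2-VL-7B-Instruct').to_dict()
cfg_72b = AutoConfig.from_pretrained('Qwen/Qwen2-VL-72B-Instruct').to_dict()

# Print only the fields that differ, e.g. hidden sizes and layer counts.
for key in sorted(set(cfg_7b) | set(cfg_72b)):
    if cfg_7b.get(key) != cfg_72b.get(key):
        print(f'{key}: 7B={cfg_7b.get(key)!r}  72B={cfg_72b.get(key)!r}')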

@PiyushSawarkar

PiyushSawarkar commented Sep 26, 2024

@PedroRASB @QwertyJack

I compared the config of Qwen2-VL-72B with Qwen2-VL-7B; only the input/output dimensions of some layers differ. So I think the current code should support the 72B model. I'll check it, but it will take some time to download the model.

Hi @irexyc, just to confirm: the current code indeed supports the 72B model; I am able to run inference with it (following the same steps mentioned here).

@PiyushSawarkar

PiyushSawarkar commented Sep 26, 2024

One issue I am encountering though: offline inference works fine, but during online inference using the lmdeploy api_server I consistently run into a CUDA out-of-memory error, even though the model is distributed across 8 H100 GPUs.
My exact command:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 lmdeploy serve api_server Qwen/Qwen2-VL-72B-Instruct --tp 8 --server-port 27770
I've tried both with and without the --enable-prefix-caching flag, but the problem persists. I also experimented with adjusting the --cache-max-entry-count parameter, but I'm still encountering the same issue.

@irexyc
Collaborator Author

irexyc commented Sep 27, 2024

@PiyushSawarkar

What --cache-max-entry-count value are you using, and how much CUDA memory is left after loading the model? The pytorch backend seems to require more runtime CUDA memory than the turbomind backend.

You can use the Qwen/Qwen2-VL-72B-Instruct-AWQ model to reduce the weight memory and leave more for the kv cache. I just tried the Qwen/Qwen2-VL-2B-Instruct-AWQ model and it worked well. Since the quantization_config is the same, I think the 72B AWQ model should also work.
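For reference, a hedged sketch of how those knobs look in code (the values are assumptions to tune for your GPUs, not recommended settings):

from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    pipe = pipeline(
        'Qwen/Qwen2-VL-72B-Instruct-AWQ',   # AWQ weights leave more free memory for the kv cache
        backend_config=PytorchEngineConfig(
            tp=8,                            # shard the weights across 8 GPUs
            session_len=16384,
            cache_max_entry_count=0.4))      # fraction of free GPU memory given to the kv cache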

@heyongxin233

@PiyushSawarkar Hello, I tried to deploy Qwen2-VL 7B on four 4090 GPUs (24GB) using the code below, but it failed.
pipe = pipeline('Qwen/Qwen2-VL-7B-Instruct', backend_config=PytorchEngineConfig(session_len=1024,tp=4))
(error screenshot omitted)
However, when tp (tensor parallelism) is set to 1, it loads successfully.

@irexyc
Collaborator Author

irexyc commented Sep 29, 2024

@heyongxin233

When using the pytorch backend with tp > 1, you have to put your code under if __name__ == '__main__', like:

if __name__ == '__main__':
  # init pipeline here
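Concretely, using the pipeline call from the report above, the guarded version would look roughly like this (a sketch; the usual reason is that worker processes started for tp > 1 re-import the main module, so the engine must not be constructed at import time):

from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    # Construct the engine only in the main process, not in re-imported workers.
    pipe = pipeline('Qwen/Qwen2-VL-7B-Instruct',
                    backend_config=PytorchEngineConfig(session_len=1024, tp=4))
    print(pipe('Describe what tensor parallelism does in one sentence.'))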

@henry16lin

@PiyushSawarkar

What --cache-max-entry-count value are you using, and how much CUDA memory is left after loading the model? The pytorch backend seems to require more runtime CUDA memory than the turbomind backend.

You can use the Qwen/Qwen2-VL-72B-Instruct-AWQ model to reduce the weight memory and leave more for the kv cache. I just tried the Qwen/Qwen2-VL-2B-Instruct-AWQ model and it worked well. Since the quantization_config is the same, I think the 72B AWQ model should also work.

@irexyc
Hi, I tried Qwen/Qwen2-VL-2B-Instruct-AWQ but ran into a version-conflict problem.
On my device I installed autoawq 0.2.4+cu122, which requires transformers>=4.35.0,<=4.38.2.
If I use transformers 4.38.2, I get the error: No module named 'transformers.models.mllama', which can only be solved with a recent version of transformers (4.45 or 4.46).

I can successfully run Qwen/Qwen2-VL-2B-Instruct with transformers==4.46.1, but it uses a bit too much CUDA memory (~6 GB) in my scenario (I have set cache_max_entry_count=0.01); that is why I want to try Qwen/Qwen2-VL-2B-Instruct-AWQ.
Do you have any recommendations? Thank you 🙏
