
[Bug]: Pixtral fails when limit_mm_per_prompt not set #8382

Closed
1 task done
BabyChouSr opened this issue Sep 11, 2024 · 12 comments · Fixed by #8415
Labels
bug Something isn't working

Comments

@BabyChouSr

Your current environment

The output of `python collect_env.py`
Your output of `python collect_env.py` here

🐛 Describe the bug

The command below does not work:

CUDA_VISIBLE_DEVICES=3 vllm serve mistralai/Pixtral-12B-2409 --port 21010 --max_num_batched_tokens 16384 --trust-remote-code --gpu-memory-utilization 0.50 --tokenizer_mode mistral

It leads to this error:

Traceback (most recent call last):
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/lmsys/vllm/vllm/entrypoints/openai/rpc/server.py", line 236, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
  File "/home/lmsys/vllm/vllm/entrypoints/openai/rpc/server.py", line 34, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
  File "/home/lmsys/vllm/vllm/engine/async_llm_engine.py", line 735, in from_engine_args
    engine = cls(
  File "/home/lmsys/vllm/vllm/engine/async_llm_engine.py", line 615, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/engine/async_llm_engine.py", line 835, in _init_engine
    return engine_class(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/engine/async_llm_engine.py", line 262, in __init__
    super().__init__(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/engine/llm_engine.py", line 338, in __init__
    self._initialize_kv_caches()
  File "/home/lmsys/vllm/vllm/engine/llm_engine.py", line 467, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
  File "/home/lmsys/vllm/vllm/executor/gpu_executor.py", line 114, in determine_num_available_blocks
    return self.driver_worker.determine_num_available_blocks()
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/worker/worker.py", line 223, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/worker/model_runner.py", line 1216, in profile_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/worker/model_runner.py", line 1543, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmsys/miniconda3/envs/vllm-source/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmsys/vllm/vllm/model_executor/models/pixtral.py", line 178, in forward
    inputs_embeds = merge_multimodal_embeddings(
  File "/home/lmsys/vllm/vllm/model_executor/models/pixtral.py", line 117, in merge_multimodal_embeddings
    assert (seq_len == N_txt +
AssertionError: seq_len 16640 should be equal to N_txt + N_img (256, 4096, 16384)

But the command below works (following the Hugging Face example):

CUDA_VISIBLE_DEVICES=3 vllm serve mistralai/Pixtral-12B-2409 --port 21010 --max_num_batched_tokens 16384 --max-model-len 8192 --trust-remote-code --gpu-memory-utilization 0.50 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4'

@BabyChouSr BabyChouSr added the bug Something isn't working label Sep 11, 2024
@BabyChouSr BabyChouSr changed the title [Bug]: Pixtral bug [Bug]: Pixtral fails when limit_mm_per_prompt not set Sep 11, 2024
@DarkLight1337
Member

DarkLight1337 commented Sep 12, 2024

@patrickvonplaten it looks like profile_run is creating more image placeholder tokens (16640) than what's expected by the model (16384) in this case. Perhaps you have to adjust how the dummy data is constructed.
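For context, here is a minimal sketch of the invariant that fails (illustrative only; the real check is merge_multimodal_embeddings in vllm/model_executor/models/pixtral.py, whose exact signature may differ):

import torch

def merge_text_and_image_embeddings(input_ids: torch.Tensor,
                                     text_embeds: torch.Tensor,
                                     image_embeds: torch.Tensor,
                                     image_token_id: int) -> torch.Tensor:
    # Every position in the prompt must be covered by exactly one embedding:
    # either a text-token embedding or an image-patch embedding.
    seq_len = input_ids.shape[0]
    n_txt, n_img = text_embeds.shape[0], image_embeds.shape[0]
    # This is the invariant whose violation produces the AssertionError above:
    # in the failing profile_run, the dummy prompt length and the number of
    # text + image embeddings no longer agree.
    assert seq_len == n_txt + n_img, (
        f"seq_len {seq_len} should be equal to N_txt + N_img ({n_txt}, {n_img})")
    merged = torch.empty(seq_len, text_embeds.shape[1], dtype=text_embeds.dtype)
    is_image = input_ids == image_token_id
    merged[is_image] = image_embeds
    merged[~is_image] = text_embeds
    return merged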

@jdf-prog

I also encounter the same error when processing this image:
[https://f2c628843e9892f5c7.gradio.live/file=/tmp/gradio/3036880890cf17b59a0cc838afc217dcd4d91ba5bc294ff42a99f6a2090f8bf2/equation.png]

What's really weird is that once I resize it to (3844, 2408), it works.

Error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/dongfuj/WorkSpace/LMM-Engines/test_vllm_pixtral.py", line 34, in <module>
[rank0]:     outputs = llm.chat(messages, sampling_params=sampling_params)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 422, in chat
[rank0]:     return self.generate(
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/utils.py", line 1032, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 348, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 720, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1600, in step
[rank0]:     outputs = self.model_executor.execute_model(
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 130, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 327, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1543, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/model_executor/models/pixtral.py", line 181, in forward
[rank0]:     inputs_embeds = merge_multimodal_embeddings(
[rank0]:   File "/home/dongfuj/.conda/envs/lmm-engines/lib/python3.10/site-packages/vllm/model_executor/models/pixtral.py", line 117, in merge_multimodal_embeddings
[rank0]:     assert (seq_len == N_txt +
[rank0]: AssertionError: seq_len 12 should be equal to N_txt + N_img (12, 4032, 0)

@jdf-prog

I also tried resizing the image to (1024, 1024), but I still get the error. It seems the error occurs when the image is closer to a square shape? A rough estimate of the image token count for different sizes is sketched below.
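Side note: the number of image placeholder tokens Pixtral emits depends on the resized resolution and aspect ratio, which may be why some sizes trigger the error and others don't. A rough back-of-the-envelope estimate, assuming 16x16 patches and a 1024-pixel cap on the longest side (both numbers are assumptions, not taken from this thread, and special image-break/end tokens are ignored):

import math

def estimate_pixtral_image_tokens(width: int, height: int,
                                  patch: int = 16, max_side: int = 1024) -> int:
    # Rough estimate only: assumes the processor downscales so the longest
    # side is at most max_side, then emits one token per 16x16 patch.
    scale = min(1.0, max_side / max(width, height))
    cols = math.ceil(round(width * scale) / patch)
    rows = math.ceil(round(height * scale) / patch)
    return cols * rows

print(estimate_pixtral_image_tokens(1024, 1024))  # -> 4096 with these assumptions
print(estimate_pixtral_image_tokens(3844, 2408))  # -> 2624 with these assumptions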

@DarkLight1337
Member

Can you try out #8399 and see if it fixes the issue which you've encountered?

@ywang96
Member

ywang96 commented Sep 12, 2024

Hello @jdf-prog! Just to confirm, you were able to launch the server, but only this particular image ran into an issue, correct?

@ywang96
Member

ywang96 commented Sep 12, 2024

Hmm, I was able to run inference with that image without any resizing

Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.71s/it, est. speed input: 1512.24 toks/s, output: 77.73 toks/s]
The document discusses a novel approach to policy optimization using preferences, specifically designed to tackle the challenges of fine-tuning language models with reinforcement learning (RL). The key innovation is the derivation of an optimal policy without an RL training loop, instead leveraging an analytical mapping from reward functions to optimal policies. This method transforms the reward functions into a loss function over policies, optimizing models of human preferences, particularly the Bradley-Terry model.

The text cites several prior works to derive an RL objective guarantee for the optimal policy based on a reward function, leading to a general solution involving a partition function. The complexity of estimating this function necessitates a reparameterization, which ultimately cancels out the partition function in practical models like the Bradley-Terry model. This reparameterization leads to the final expression of the optimal policy in terms of the optimal and reference policies alone, simplifying the preference model considerably. The main insight is to convert a loss function dependent on reward functions into one based directly on policies, improving the efficiency and feasibility of the optimization process.

@patrickvonplaten
Contributor

Double checking this command:

CUDA_VISIBLE_DEVICES=3 vllm serve mistralai/Pixtral-12B-2409 --port 21010 --max_num_batched_tokens 16384 --trust-remote-code --gpu-memory-utilization 0.50 --tokenizer_mode mistral

BTW there is no need to pass --trust-remote-code here

@patrickvonplaten
Contributor

Ah yes, I see: when passing --max_num_batched_tokens but not --limit_mm_per_prompt 'image=4', profile_run throws an error. I think that's because the image token creation in Pixtral's dummy data is not great. Let me open a PR to fix it!

@jdf-prog

Hello @jdf-prog! Just to confirm, you were able to launch the server, but only this particular image ran into an issue, correct?

Yes, only this particular image. The following is the code with which I encounter this error. It should be simple to reproduce.

from vllm import LLM
from vllm.sampling_params import SamplingParams

model_name = "mistralai/Pixtral-12B-2409"

sampling_params = SamplingParams(max_tokens=8192)

llm = LLM(model=model_name, tokenizer_mode="mistral", max_model_len=65536, limit_mm_per_prompt={"image":4})

prompt = "Can you derive Equation 6 from the image?"
image_url="https://f2c628843e9892f5c7.gradio.live/file=/tmp/gradio/3036880890cf17b59a0cc838afc217dcd4d91ba5bc294ff42a99f6a2090f8bf2/equation.png"

messages = [
    {
        "role": "user",
        "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": image_url}}]
    },
]

outputs = llm.chat(messages, sampling_params=sampling_params)

print(outputs[0].outputs[0].text)

@jdf-prog

And here is the code that works again after the resize:

from vllm import LLM
from vllm.sampling_params import SamplingParams
from PIL import Image
from io import BytesIO
import base64
import requests

def encode_image(image:Image.Image, image_format="PNG") -> str:
    im_file = BytesIO()
    image.save(im_file, format=image_format)
    im_bytes = im_file.getvalue()
    im_64 = base64.b64encode(im_bytes).decode("utf-8")
    return im_64
    
model_name = "mistralai/Pixtral-12B-2409"

sampling_params = SamplingParams(max_tokens=8192)

llm = LLM(model=model_name, tokenizer_mode="mistral", max_model_len=65536, limit_mm_per_prompt={"image":4})

prompt = "Can you derive Equation 6 from the image?"
image_url="https://f2c628843e9892f5c7.gradio.live/file=/tmp/gradio/3036880890cf17b59a0cc838afc217dcd4d91ba5bc294ff42a99f6a2090f8bf2/equation.png"

image = Image.open(BytesIO(requests.get(image_url).content))
image = image.resize((3844, 2408))
new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"

messages = [
    {
        "role": "user",
        "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": new_image_url}}]
    },
]

outputs = llm.chat(messages, sampling_params=sampling_params)

print(outputs[0].outputs[0].text)

@ywang96
Member

ywang96 commented Sep 12, 2024


@jdf-prog I'm pretty certain this is due to the fact that chunked prefill currently works flakily with VLMs. When max-model-len is bigger than 32768, chunked prefill is turned on by default with max-num-batched-tokens set to a fixed number (4096 for VLMs). This is something we should definitely address in the near future; perhaps we should encourage explicitly turning off chunked prefill for VLMs for now.
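For example, explicitly opting out of chunked prefill while keeping the rest of your setup unchanged should look something like this (untested sketch):

from vllm import LLM
from vllm.sampling_params import SamplingParams

# Same initialization as in your snippet, but with chunked prefill explicitly
# disabled so the prompt is not split into prefill chunks.
llm = LLM(
    model="mistralai/Pixtral-12B-2409",
    tokenizer_mode="mistral",
    max_model_len=65536,
    limit_mm_per_prompt={"image": 4},
    enable_chunked_prefill=False,
)
sampling_params = SamplingParams(max_tokens=8192)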

In the meantime, can you modify your model initialization to match what's in examples/offline_inference_pixtral.py?

    model_name = "mistralai/Pixtral-12B-2409"
    max_img_per_msg = 5
    max_tokens_per_img = 4096

    sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
    llm = LLM(
        model=model_name,
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": max_img_per_msg},
        max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    )
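For the server entrypoint from the original report, the equivalent flags would presumably look like this (same numbers as above, just expressed on the command line):

vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=5' --max_num_batched_tokens 20480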

@jdf-prog

Thanks, I tried this and it seems to work. Thanks for the help!
