
[Bugfix][Multi Modal] Fix incorrect Molmo token processing#26873

Merged
DarkLight1337 merged 1 commit into vllm-project:main from sangho-vision:fix_molmo_chat
Oct 15, 2025

Conversation

Contributor

@sangho-vision sangho-vision commented Oct 15, 2025

Purpose

When serving a Molmo model online using chat completion, the vLLM code first applies Molmo's chat template to the input text and tokenizes it. It then calls the custom _apply_hf_processor_tokens_only method in MolmoMultiModalProcessor for further processing.
However, _apply_hf_processor_tokens_only internally calls the get_tokens_input method of Molmo's HF processor, which applies the chat template a second time, resulting in double templating.

This behavior can be verified by running the following code:

import torch

from vllm import LLM
from vllm.sampling_params import SamplingParams
from transformers import AutoProcessor

model = LLM(
    model="allenai/Molmo-7B-D-0924",
    tensor_parallel_size=torch.cuda.device_count(),
    trust_remote_code=True,
    dtype='bfloat16',
    gpu_memory_utilization=0.95,
)

processor = AutoProcessor.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",
)

sampling_params = SamplingParams(max_tokens=448, temperature=0)

image_url = "https://www.visitscotland.com/binaries/content/gallery/visitscotland/cms-images/2022/06/24/clashnessie-bay-car-road"

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "Describe the image."
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
        ],
    },
]

outputs = model.chat(messages, sampling_params=sampling_params)
prompt = processor.tokenizer.decode(outputs[0].prompt_token_ids, skip_special_tokens=True)
print(prompt)

This prints:

 User: User: Describe the image. Assistant: Assistant:

The double "User:" and "Assistant:" indicate that the chat template is applied twice.
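The failure mode can be reproduced in isolation with a toy template function. The apply_template below is a hypothetical stand-in for Molmo's real chat template (not vLLM or Molmo code), used only to show why applying the template twice duplicates the role markers:

```python
# Hypothetical stand-in for Molmo's chat template (for illustration only):
# wrap the user text in "User:" / "Assistant:" role markers.
def apply_template(text: str) -> str:
    return f" User: {text} Assistant:"

prompt = "Describe the image."
once = apply_template(prompt)   # what the chat endpoint should produce
twice = apply_template(once)    # what the buggy path effectively produced

print(once)
print(twice)  # role markers are duplicated, matching the bug report
```

Applying the toy template to already-templated text nests the role markers, which is exactly the "User: User: ... Assistant: Assistant:" pattern observed above.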

This PR fixes this double templating issue.

Changes Made

Make sure that _apply_hf_processor_tokens_only uses the "none" message format instead of the model's default configuration:

        # The chat template is already applied to the prompt tokens
        # Use message_format="none" to avoid applying it again
        # Prepend an empty space if `always_start_with_space` is True
        tokens = processor.processor.get_tokens_input(  # type: ignore
            self.info.get_tokenizer().decode(prompt_tokens),
            message_format="none",
            always_start_with_space=processor.always_start_with_space,
        )

Test Plan

Run the same code snippet above.

Test Result

 User: Describe the image. Assistant:

The double templating is resolved.

Signed-off-by: sanghol <sanghol@allenai.org>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a bug where the Molmo chat template was being applied twice, leading to incorrect model inputs. The fix is simple and effective. My review includes a suggestion to optimize the token processing logic to avoid an inefficient decode-then-encode cycle, which can improve performance for long prompts.

Comment on lines +1267 to 1274

         # The chat template is already applied to the prompt tokens
         # Use message_format="none" to avoid applying it again
         # Prepend an empty space if `always_start_with_space` is True
         tokens = processor.processor.get_tokens_input(  # type: ignore
             self.info.get_tokenizer().decode(prompt_tokens),
-            message_format=processor.message_format,
+            message_format="none",
             always_start_with_space=processor.always_start_with_space,
         )

Severity: high

While this correctly fixes the double-templating issue, calling get_tokens_input here incurs a decode-then-encode cycle for the entire prompt on every request. This can be inefficient for long prompts.

Since the main purpose of this call (with message_format="none") is to enforce the always_start_with_space logic, we can optimize this by inlining a more efficient version of that logic. The suggestion below avoids re-encoding the prompt if it already starts with a space, which can provide a significant performance improvement.

Suggested change

-        # The chat template is already applied to the prompt tokens
-        # Use message_format="none" to avoid applying it again
-        # Prepend an empty space if `always_start_with_space` is True
-        tokens = processor.processor.get_tokens_input(  # type: ignore
-            self.info.get_tokenizer().decode(prompt_tokens),
-            message_format="none",
-            always_start_with_space=processor.always_start_with_space,
-        )
+        tokenizer = self.info.get_tokenizer()
+        # The chat template is already applied. The logic below is an
+        # optimized reimplementation of `processor.get_tokens_input`
+        # with `message_format="none"`. It avoids a decode-encode cycle
+        # if the prompt already starts with a space, improving performance.
+        if processor.always_start_with_space:
+            decoded_prompt = tokenizer.decode(prompt_tokens)
+            if not decoded_prompt.startswith(" "):
+                tokens = tokenizer.encode(" " + decoded_prompt,
+                                          add_special_tokens=False)
+            else:
+                tokens = prompt_tokens
+        else:
+            tokens = prompt_tokens
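The proposed fast path can be sketched in isolation. The ToyTokenizer below is a character-level stand-in for the real HF tokenizer (hypothetical, so the snippet runs without model weights), and ensure_leading_space mirrors the control flow of the suggestion: decode and re-encode only when a leading space actually needs to be prepended:

```python
class ToyTokenizer:
    """Character-level stand-in for the real HF tokenizer (illustration only)."""

    def decode(self, token_ids):
        return "".join(token_ids)

    def encode(self, text, add_special_tokens=False):
        return list(text)


def ensure_leading_space(prompt_tokens, tokenizer, always_start_with_space=True):
    # Fast path: nothing to enforce, return the tokens untouched.
    if not always_start_with_space:
        return prompt_tokens
    decoded = tokenizer.decode(prompt_tokens)
    # Already starts with a space: skip the re-encode entirely.
    if decoded.startswith(" "):
        return prompt_tokens
    # Only this branch pays for the decode-then-encode cycle.
    return tokenizer.encode(" " + decoded, add_special_tokens=False)


tok = ToyTokenizer()
print(ensure_leading_space(list("User: hi"), tok))   # space prepended
print(ensure_leading_space(list(" User: hi"), tok))  # returned unchanged
```

The design point is that the common case (prompt already space-prefixed, or always_start_with_space disabled) never re-encodes, which is where the reviewer expects the savings on long prompts.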

Member

@DarkLight1337 DarkLight1337 left a comment


Thanks for fixing

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 15, 2025 03:08
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 15, 2025
@DarkLight1337 DarkLight1337 merged commit 8865da1 into vllm-project:main Oct 15, 2025
56 checks passed
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
@DarkLight1337 DarkLight1337 mentioned this pull request Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
@sangho-vision sangho-vision deleted the fix_molmo_chat branch December 2, 2025 01:09

Labels

ready ONLY add when PR is ready to merge/full CI is needed

3 participants