[Bugfix][Multi Modal] Fix incorrect Molmo token processing #26873
DarkLight1337 merged 1 commit into vllm-project:main from sangho-vision:fix_molmo_chat
Conversation
Signed-off-by: sanghol <sanghol@allenai.org>
Code Review
This pull request correctly addresses a bug where the Molmo chat template was applied twice, leading to incorrect model inputs. The fix is simple and effective. My review includes a suggestion to optimize the token processing logic to avoid an inefficient decode-then-encode cycle, which can improve performance for long prompts.
```diff
  # The chat template is already applied to the prompt tokens
  # Use message_format="none" to avoid applying it again
  # Prepend an empty space if `always_start_with_space` is True
  tokens = processor.processor.get_tokens_input(  # type: ignore
      self.info.get_tokenizer().decode(prompt_tokens),
-     message_format=processor.message_format,
+     message_format="none",
      always_start_with_space=processor.always_start_with_space,
  )
```
While this correctly fixes the double-templating issue, calling get_tokens_input here incurs a decode-then-encode cycle for the entire prompt on every request. This can be inefficient for long prompts.
Since the main purpose of this call (with message_format="none") is to enforce the always_start_with_space logic, we can optimize this by inlining a more efficient version of that logic. The suggestion below avoids re-encoding the prompt if it already starts with a space, which can provide a significant performance improvement.
```diff
- # The chat template is already applied to the prompt tokens
- # Use message_format="none" to avoid applying it again
- # Prepend an empty space if `always_start_with_space` is True
- tokens = processor.processor.get_tokens_input(  # type: ignore
-     self.info.get_tokenizer().decode(prompt_tokens),
-     message_format="none",
-     always_start_with_space=processor.always_start_with_space,
- )
+ tokenizer = self.info.get_tokenizer()
+ # The chat template is already applied. The logic below is an
+ # optimized reimplementation of `processor.get_tokens_input`
+ # with `message_format="none"`. It avoids a decode-encode cycle
+ # if the prompt already starts with a space, improving performance.
+ if processor.always_start_with_space:
+     decoded_prompt = tokenizer.decode(prompt_tokens)
+     if not decoded_prompt.startswith(" "):
+         tokens = tokenizer.encode(" " + decoded_prompt,
+                                   add_special_tokens=False)
+     else:
+         tokens = prompt_tokens
+ else:
+     tokens = prompt_tokens
```
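To make the suggested fast path concrete, here is a hedged, self-contained sketch of the same logic. The `ToyTokenizer` (character-level) is a stand-in I introduce purely for illustration; the real code operates on the model's HF tokenizer, and `process_tokens` is a hypothetical helper name, not a vLLM API.

```python
class ToyTokenizer:
    """Illustrative character-level tokenizer (one token per character)."""

    def encode(self, text, add_special_tokens=False):
        return [ord(c) for c in text]

    def decode(self, ids):
        return "".join(chr(i) for i in ids)


def process_tokens(prompt_tokens, tokenizer, always_start_with_space=True):
    # Fast path: only decode and re-encode when a leading space is
    # required and not already present; otherwise return the tokens
    # unchanged, skipping the decode-encode round trip entirely.
    if always_start_with_space:
        decoded = tokenizer.decode(prompt_tokens)
        if not decoded.startswith(" "):
            return tokenizer.encode(" " + decoded, add_special_tokens=False)
    return prompt_tokens


tok = ToyTokenizer()
# A prompt without a leading space gets one prepended.
assert tok.decode(process_tokens(tok.encode("Hello"), tok)) == " Hello"
# A prompt that already starts with a space is returned as-is.
assert process_tokens(tok.encode(" Hi"), tok) == tok.encode(" Hi")
```

The key design point matches the review comment: the common case (prompt already normalized) costs one decode and a `startswith` check instead of a full re-encode of the prompt.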
Purpose
When serving a Molmo model online using chat completion, the vLLM code first applies Molmo's chat template to the input text and tokenizes it. It then calls the custom `_apply_hf_processor_tokens_only` method in `MolmoMultiModalProcessor` for further processing.

However, `_apply_hf_processor_tokens_only` internally calls the `get_tokens_input` method of Molmo's HF processor, which applies the chat template once again, resulting in double templating.

This behavior can be verified by running the following code:

This prints:

The double "User:" and "Assistant:" markers indicate that the chat template is applied twice.
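The PR's original reproduction snippet is not included here, but the symptom can be illustrated with a toy stand-in. The `apply_chat_template` function below is hypothetical (not the real Molmo template); it only shows how applying a chat-style template twice yields nested role markers like those described above.

```python
def apply_chat_template(text: str) -> str:
    """Toy stand-in for a Molmo-style chat template."""
    return f" User: {text} Assistant:"


prompt = "Describe this image."
once = apply_chat_template(prompt)
twice = apply_chat_template(once)

print(once)   # " User: Describe this image. Assistant:"
print(twice)  # " User:  User: Describe this image. Assistant: Assistant:"

# The double-templated string contains each role marker twice.
assert twice.count("User:") == 2 and twice.count("Assistant:") == 2
```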
This PR fixes this double templating issue.
Changes Made
Make sure that `_apply_hf_processor_tokens_only` uses the `"none"` message format instead of the model's default configuration:

Test Plan
Run the same code snippet above.
Test Result
The double templating issue is resolved.