diff --git a/components/src/dynamo/vllm/multimodal_handlers/preprocessor_handler.py b/components/src/dynamo/vllm/multimodal_handlers/preprocessor_handler.py index a9bb7ef8e1d..3fb4a0df70f 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/preprocessor_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/preprocessor_handler.py @@ -23,6 +23,7 @@ MultiModalRequest, MyRequestOutput, ProcessMixIn, + extract_user_text, vLLMMultimodalRequest, ) @@ -156,10 +157,7 @@ async def generate(self, raw_request: MultiModalRequest, context): raise ValueError("prompt_template must contain '' placeholder") # Safely extract user text - try: - user_text = raw_request.messages[0].content[0].text - except (IndexError, AttributeError) as e: - raise ValueError(f"Invalid message structure: {e}") + user_text = extract_user_text(raw_request.messages) prompt = template.replace("", user_text) diff --git a/components/src/dynamo/vllm/multimodal_utils/__init__.py b/components/src/dynamo/vllm/multimodal_utils/__init__.py index 7d1944cc4f9..17523e1a44d 100644 --- a/components/src/dynamo/vllm/multimodal_utils/__init__.py +++ b/components/src/dynamo/vllm/multimodal_utils/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text from dynamo.vllm.multimodal_utils.chat_processor import ( ChatProcessor, CompletionsProcessor, @@ -31,6 +32,7 @@ "CompletionsProcessor", "ProcessMixIn", "encode_image_embeddings", + "extract_user_text", "get_encoder_components", "get_http_client", "ImageLoader", diff --git a/components/src/dynamo/vllm/multimodal_utils/chat_message_utils.py b/components/src/dynamo/vllm/multimodal_utils/chat_message_utils.py new file mode 100644 index 00000000000..a3352b85182 --- /dev/null +++ b/components/src/dynamo/vllm/multimodal_utils/chat_message_utils.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for processing chat messages.""" + +from typing import List + +from dynamo.vllm.multimodal_utils.protocol import ChatMessage + + +def extract_user_text(messages: List[ChatMessage]) -> str: + """Extract and concatenate text content from user messages.""" + + # This function finds all text content items from "user" role messages, + # and concatenates them. For multi-turn conversation, it adds a newline + # between each turn. This is not a perfect solution as we encode multi-turn + # conversation as a single turn. However, multi-turn conversation in a + # single request is not well defined in the spec. + + # TODO: Revisit this later when adding multi-turn conversation support. + user_texts = [] + for message in messages: + if message.role == "user": + # Collect all text content items from this user message + text_parts = [] + for item in message.content: + if item.type == "text" and item.text: + text_parts.append(item.text) + # If this user message has text content, join it and add to user_texts + if text_parts: + user_texts.append("".join(text_parts)) + + if not user_texts: + raise ValueError("No text content found in user messages") + + # Join all user turns with newline separator + return "\n".join(user_texts) diff --git a/components/src/dynamo/vllm/multimodal_utils/chat_processor.py b/components/src/dynamo/vllm/multimodal_utils/chat_processor.py index e70a794040a..d3e1da4b5e8 100644 --- a/components/src/dynamo/vllm/multimodal_utils/chat_processor.py +++ b/components/src/dynamo/vllm/multimodal_utils/chat_processor.py @@ -59,12 +59,6 @@ class ProcessMixIn(ProcessMixInRequired): Mixin for pre and post processing for vLLM """ - engine_args: AsyncEngineArgs - chat_processor: "ChatProcessor | None" - completions_processor: "CompletionsProcessor | None" - model_config: ModelConfig - default_sampling_params: SamplingParams - def __init__(self): pass diff --git a/components/src/dynamo/vllm/multimodal_utils/protocol.py b/components/src/dynamo/vllm/multimodal_utils/protocol.py index 49a2fe2e078..b619f5f1ebd 100644 --- a/components/src/dynamo/vllm/multimodal_utils/protocol.py +++ b/components/src/dynamo/vllm/multimodal_utils/protocol.py @@ -18,7 +18,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union import msgspec -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic_core import core_schema from typing_extensions import NotRequired from vllm.inputs.data import TokensPrompt @@ -89,9 +89,13 @@ def parse_sampling_params(cls, v: Any) -> SamplingParams: return SamplingParams(**v) return v + @field_serializer("sampling_params") + def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]: + """Serialize SamplingParams using msgspec and return as dict.""" + return json.loads(msgspec.json.encode(value)) + model_config = ConfigDict( arbitrary_types_allowed=True, - json_encoders={SamplingParams: lambda v: json.loads(msgspec.json.encode(v))}, ) diff --git a/components/src/dynamo/vllm/tests/test_chat_message_utils.py b/components/src/dynamo/vllm/tests/test_chat_message_utils.py new file mode 100644 index 00000000000..757eaee7ecf --- /dev/null +++ b/components/src/dynamo/vllm/tests/test_chat_message_utils.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for chat message utility functions.""" + +import pytest + +from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text +from dynamo.vllm.multimodal_utils.protocol import ( + ChatMessage, + ImageContent, + ImageURLDetail, + TextContent, +) + +pytestmark = [ + pytest.mark.unit, + pytest.mark.vllm, + pytest.mark.gpu_0, + pytest.mark.pre_merge, +] + + +def test_extract_user_text_single_message(): + """Test extracting text from a single user message with one text content.""" + messages = [ + ChatMessage( + role="user", content=[TextContent(type="text", text="Hello, world!")] + ) + ] + + result = extract_user_text(messages) + assert result == "Hello, world!" + + +def test_extract_user_text_multiple_text_parts(): + """Test extracting text from a user message with multiple text content items.""" + messages = [ + ChatMessage( + role="user", + content=[ + TextContent(type="text", text="First part "), + ImageContent( + type="image_url", + image_url=ImageURLDetail(url="http://example.com/image.jpg"), + ), + TextContent(type="text", text="second part"), + ], + ) + ] + + result = extract_user_text(messages) + assert result == "First part second part" + + +def test_extract_user_text_multi_turn(): + """Test extracting text from multi-turn conversation.""" + messages = [ + ChatMessage( + role="user", content=[TextContent(type="text", text="First question")] + ), + ChatMessage( + role="assistant", content=[TextContent(type="text", text="First answer")] + ), + ChatMessage( + role="user", content=[TextContent(type="text", text="Second question")] + ), + ] + + result = extract_user_text(messages) + assert result == "First question\nSecond question" + + +def test_extract_user_text_only_images(): + """Test that ValueError is raised when messages contain only images.""" + messages = [ + ChatMessage( + role="user", + content=[ + ImageContent( + type="image_url", + image_url=ImageURLDetail(url="http://example.com/image.jpg"), + ) + ], + ) + ] + + with pytest.raises(ValueError, match="No text content found in user messages"): + extract_user_text(messages) + + +def test_extract_user_text_empty_messages(): + """Test that ValueError is raised when messages list is empty.""" + messages: list[ChatMessage] = [] + + with pytest.raises(ValueError, match="No text content found in user messages"): + extract_user_text(messages) + + +def test_extract_user_text_no_user_messages(): + """Test that ValueError is raised when there are no user role messages.""" + messages = [ + ChatMessage( + role="assistant", + content=[TextContent(type="text", text="Just an assistant message")], + ) + ] + + with pytest.raises(ValueError, match="No text content found in user messages"): + extract_user_text(messages) + + +def test_extract_user_text_mixed_roles(): + """Test extracting text only from user messages, ignoring other roles.""" + messages = [ + ChatMessage( + role="system", content=[TextContent(type="text", text="System prompt")] + ), + ChatMessage( + role="user", content=[TextContent(type="text", text="User message 1")] + ), + ChatMessage( + role="assistant", + content=[TextContent(type="text", text="Assistant response")], + ), + ChatMessage( + role="user", content=[TextContent(type="text", text="User message 2")] + ), + ] + + result = extract_user_text(messages) + assert result == "User message 1\nUser message 2" + + +def test_extract_user_text_empty_text_content(): + """Test that empty text content items are ignored.""" + messages = [ + ChatMessage( + role="user", + content=[ + TextContent(type="text", text=""), + TextContent(type="text", text="Valid text"), + TextContent(type="text", text=""), + ], + ) + ] + + result = extract_user_text(messages) + assert result == "Valid text" diff --git a/examples/multimodal/components/processor.py b/examples/multimodal/components/processor.py index 2f61cf7adb2..1da06c2639c 100644 --- a/examples/multimodal/components/processor.py +++ b/examples/multimodal/components/processor.py @@ -27,6 +27,7 @@ # To import example local module sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) from utils.args import Config, base_parse_args, parse_endpoint +from utils.chat_message_utils import extract_user_text from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.protocol import ( MultiModalInput, @@ -203,15 +204,7 @@ async def generate(self, raw_request: MultiModalRequest): if "" not in template: raise ValueError("prompt_template must contain '' placeholder") - # Safely extract user text - find the text content item - user_text = None - for message in raw_request.messages: - for item in message.content: - if item.type == "text": - user_text = item.text - break - if user_text is None: - raise ValueError("No text content found in the request messages") + user_text = extract_user_text(raw_request.messages) prompt = template.replace("", user_text) diff --git a/examples/multimodal/utils/chat_message_utils.py b/examples/multimodal/utils/chat_message_utils.py new file mode 100644 index 00000000000..a1c1cdac81e --- /dev/null +++ b/examples/multimodal/utils/chat_message_utils.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for processing chat messages.""" + + +def extract_user_text(messages) -> str: + """Extract and concatenate text content from user messages.""" + user_texts = [] + for message in messages: + if message.role == "user": + # Collect all text content items from this user message + text_parts = [] + for item in message.content: + if item.type == "text" and item.text: + text_parts.append(item.text) + # If this user message has text content, join it and add to user_texts + if text_parts: + user_texts.append("".join(text_parts)) + + if not user_texts: + raise ValueError("No text content found in user messages") + + # Join all user turns with newline separator + return "\n".join(user_texts) diff --git a/examples/multimodal/utils/protocol.py b/examples/multimodal/utils/protocol.py index 4f836f6728d..06e33c92f4a 100644 --- a/examples/multimodal/utils/protocol.py +++ b/examples/multimodal/utils/protocol.py @@ -18,7 +18,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union import msgspec -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic_core import core_schema from typing_extensions import NotRequired from vllm.inputs.data import TokensPrompt @@ -89,9 +89,13 @@ def parse_sampling_params(cls, v: Any) -> SamplingParams: return SamplingParams(**v) return v + @field_serializer("sampling_params") + def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]: + """Serialize SamplingParams using msgspec and return as dict.""" + return json.loads(msgspec.json.encode(value)) + model_config = ConfigDict( arbitrary_types_allowed=True, - json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)}, )