Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
MultiModalRequest,
MyRequestOutput,
ProcessMixIn,
extract_user_text,
vLLMMultimodalRequest,
)

Expand Down Expand Up @@ -156,10 +157,7 @@ async def generate(self, raw_request: MultiModalRequest, context):
raise ValueError("prompt_template must contain '<prompt>' placeholder")

# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
user_text = extract_user_text(raw_request.messages)

prompt = template.replace("<prompt>", user_text)

Expand Down
2 changes: 2 additions & 0 deletions components/src/dynamo/vllm/multimodal_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.chat_processor import (
ChatProcessor,
CompletionsProcessor,
Expand Down Expand Up @@ -31,6 +32,7 @@
"CompletionsProcessor",
"ProcessMixIn",
"encode_image_embeddings",
"extract_user_text",
"get_encoder_components",
"get_http_client",
"ImageLoader",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Utility functions for processing chat messages."""

from typing import List

from dynamo.vllm.multimodal_utils.protocol import ChatMessage


def extract_user_text(messages: List[ChatMessage]) -> str:
    """Extract and concatenate text content from user messages.

    Text items inside a single user message are concatenated directly
    (no separator); separate user turns are joined with a newline. This
    flattens a multi-turn conversation into a single prompt string —
    multi-turn in a single request is not well defined in the spec.

    TODO: Revisit when proper multi-turn conversation support is added.

    Raises:
        ValueError: if no user message contains non-empty text.
    """
    # One entry per user turn: the concatenation of that turn's text items.
    per_turn = []
    for msg in messages:
        if msg.role != "user":
            continue
        joined = "".join(
            part.text for part in msg.content if part.type == "text" and part.text
        )
        # Image-only (or empty-text) user turns contribute nothing.
        if joined:
            per_turn.append(joined)

    if not per_turn:
        raise ValueError("No text content found in user messages")

    return "\n".join(per_turn)
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ class ProcessMixIn(ProcessMixInRequired):
Mixin for pre and post processing for vLLM
"""

engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams

def __init__(self):
    # Intentionally a no-op: the mixin itself carries no state.
    # NOTE(review): initialization of processor/engine state appears to be
    # left to the concrete classes using this mixin — confirm against callers.
    pass

Expand Down
8 changes: 6 additions & 2 deletions components/src/dynamo/vllm/multimodal_utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, List, Literal, Optional, Tuple, Union

import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
Expand Down Expand Up @@ -89,9 +89,13 @@ def parse_sampling_params(cls, v: Any) -> SamplingParams:
return SamplingParams(**v)
return v

@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
    """Serialize SamplingParams using msgspec and return as dict.

    vLLM's SamplingParams is not natively pydantic-serializable, so we
    round-trip through msgspec's JSON encoder and decode back to a plain
    dict that pydantic can embed in the model's serialized output.
    """
    # msgspec.json.encode returns bytes; json.loads turns them into a dict.
    return json.loads(msgspec.json.encode(value))

model_config = ConfigDict(
    # SamplingParams is not a pydantic model; allow it as a field type.
    arbitrary_types_allowed=True,
    # NOTE(review): json_encoders is deprecated in pydantic v2; the
    # serialize_sampling_params field_serializer already covers this
    # conversion — consider removing this entry once confirmed unused.
    json_encoders={SamplingParams: lambda v: json.loads(msgspec.json.encode(v))},
)


Expand Down
149 changes: 149 additions & 0 deletions components/src/dynamo/vllm/tests/test_chat_message_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for chat message utility functions."""

import pytest

from dynamo.vllm.multimodal_utils.chat_message_utils import extract_user_text
from dynamo.vllm.multimodal_utils.protocol import (
ChatMessage,
ImageContent,
ImageURLDetail,
TextContent,
)

# Marks applied to every test in this module (pure-unit, vLLM component,
# no GPU required, run pre-merge).
pytestmark = [
    pytest.mark.unit,
    pytest.mark.vllm,
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
]


def test_extract_user_text_single_message():
    """A lone user message with one text item round-trips unchanged."""
    chat = [
        ChatMessage(
            role="user", content=[TextContent(type="text", text="Hello, world!")]
        )
    ]
    assert extract_user_text(chat) == "Hello, world!"


def test_extract_user_text_multiple_text_parts():
    """Text items in one user message are concatenated; images are skipped."""
    image = ImageContent(
        type="image_url",
        image_url=ImageURLDetail(url="http://example.com/image.jpg"),
    )
    chat = [
        ChatMessage(
            role="user",
            content=[
                TextContent(type="text", text="First part "),
                image,
                TextContent(type="text", text="second part"),
            ],
        )
    ]
    assert extract_user_text(chat) == "First part second part"


def test_extract_user_text_multi_turn():
    """User turns are joined with a newline; assistant turns are ignored."""

    def turn(role: str, text: str) -> ChatMessage:
        return ChatMessage(role=role, content=[TextContent(type="text", text=text)])

    chat = [
        turn("user", "First question"),
        turn("assistant", "First answer"),
        turn("user", "Second question"),
    ]
    assert extract_user_text(chat) == "First question\nSecond question"


def test_extract_user_text_only_images():
    """An image-only user message yields no text and raises ValueError."""
    image_only = ChatMessage(
        role="user",
        content=[
            ImageContent(
                type="image_url",
                image_url=ImageURLDetail(url="http://example.com/image.jpg"),
            )
        ],
    )
    with pytest.raises(ValueError, match="No text content found in user messages"):
        extract_user_text([image_only])


def test_extract_user_text_empty_messages():
    """An empty message list raises ValueError."""
    with pytest.raises(ValueError, match="No text content found in user messages"):
        extract_user_text([])


def test_extract_user_text_no_user_messages():
    """Messages with no user-role entries raise ValueError."""
    assistant_only = ChatMessage(
        role="assistant",
        content=[TextContent(type="text", text="Just an assistant message")],
    )
    with pytest.raises(ValueError, match="No text content found in user messages"):
        extract_user_text([assistant_only])


def test_extract_user_text_mixed_roles():
    """Only user-role messages contribute text; system/assistant are skipped."""

    def turn(role: str, text: str) -> ChatMessage:
        return ChatMessage(role=role, content=[TextContent(type="text", text=text)])

    chat = [
        turn("system", "System prompt"),
        turn("user", "User message 1"),
        turn("assistant", "Assistant response"),
        turn("user", "User message 2"),
    ]
    assert extract_user_text(chat) == "User message 1\nUser message 2"


def test_extract_user_text_empty_text_content():
    """Empty-string text items contribute nothing to the result."""
    chat = [
        ChatMessage(
            role="user",
            content=[
                TextContent(type="text", text=""),
                TextContent(type="text", text="Valid text"),
                TextContent(type="text", text=""),
            ],
        )
    ]
    assert extract_user_text(chat) == "Valid text"
11 changes: 2 additions & 9 deletions examples/multimodal/components/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_message_utils import extract_user_text
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import (
MultiModalInput,
Expand Down Expand Up @@ -203,15 +204,7 @@ async def generate(self, raw_request: MultiModalRequest):
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")

# Safely extract user text - find the text content item
user_text = None
for message in raw_request.messages:
for item in message.content:
if item.type == "text":
user_text = item.text
break
if user_text is None:
raise ValueError("No text content found in the request messages")
user_text = extract_user_text(raw_request.messages)

prompt = template.replace("<prompt>", user_text)

Expand Down
25 changes: 25 additions & 0 deletions examples/multimodal/utils/chat_message_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Utility functions for processing chat messages."""


def extract_user_text(messages) -> str:
    """Extract and concatenate text content from user messages.

    Text items inside a single user message are concatenated directly;
    separate user turns are joined with a newline, flattening a
    multi-turn conversation into one prompt string.

    Raises:
        ValueError: if no user message contains non-empty text.
    """
    turns = [
        "".join(c.text for c in m.content if c.type == "text" and c.text)
        for m in messages
        if m.role == "user"
    ]
    # Drop user turns that carried no text (e.g. image-only messages).
    turns = [t for t in turns if t]

    if not turns:
        raise ValueError("No text content found in user messages")

    return "\n".join(turns)
8 changes: 6 additions & 2 deletions examples/multimodal/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, List, Literal, Optional, Tuple, Union

import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
Expand Down Expand Up @@ -89,9 +89,13 @@ def parse_sampling_params(cls, v: Any) -> SamplingParams:
return SamplingParams(**v)
return v

@field_serializer("sampling_params")
def serialize_sampling_params(self, value: SamplingParams) -> dict[str, Any]:
    """Serialize SamplingParams using msgspec and return as dict.

    SamplingParams is not natively pydantic-serializable; round-trip it
    through msgspec's JSON encoder and decode back into a plain dict.
    """
    # msgspec.json.encode returns bytes; json.loads converts them to a dict.
    return json.loads(msgspec.json.encode(value))

model_config = ConfigDict(
    # SamplingParams is not a pydantic model; allow it as a field type.
    arbitrary_types_allowed=True,
    # BUG FIX: msgspec.json.encode returns bytes, which are not
    # JSON-serializable output for pydantic — decode back to a dict.
    # This also matches the components/src copy of this protocol module.
    # NOTE(review): json_encoders is deprecated in pydantic v2 and the
    # serialize_sampling_params field_serializer already covers this path;
    # consider removing this entry once confirmed unused.
    json_encoders={SamplingParams: lambda v: json.loads(msgspec.json.encode(v))},
)


Expand Down
Loading