Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llama-index-core/llama_index/core/base/llms/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ class ChatMessage(BaseModel):
role: MessageRole = MessageRole.USER
additional_kwargs: dict[str, Any] = Field(default_factory=dict)
blocks: list[ContentBlock] = Field(default_factory=list)
id: Optional[str] = Field(
default=None, description="Optional unique identifier for the message"
)

def __init__(self, /, content: Any | None = None, **data: Any) -> None:
"""
Expand Down
50 changes: 50 additions & 0 deletions llama-index-core/llama_index/core/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,10 @@ async def _manage_queue(self) -> None:

async def aput(self, message: ChatMessage) -> None:
"""Add a message to the chat store and process waterfall logic if needed."""
# Ensure message has an ID if not provided
if message.id is None:
message.id = str(uuid.uuid4())

# Add the message to the chat store
await self.sql_store.add_message(
self.session_id, message, status=MessageStatus.ACTIVE
Expand All @@ -736,6 +740,11 @@ async def aput(self, message: ChatMessage) -> None:

async def aput_messages(self, messages: List[ChatMessage]) -> None:
"""Add a list of messages to the chat store and process waterfall logic if needed."""
# Ensure all messages have IDs if not provided
for message in messages:
if message.id is None:
message.id = str(uuid.uuid4())

# Add the messages to the chat store
await self.sql_store.add_messages(
self.session_id, messages, status=MessageStatus.ACTIVE
Expand All @@ -746,6 +755,11 @@ async def aput_messages(self, messages: List[ChatMessage]) -> None:

async def aset(self, messages: List[ChatMessage]) -> None:
"""Set the chat history."""
# Ensure all messages have IDs if not provided
for message in messages:
if message.id is None:
message.id = str(uuid.uuid4())

await self.sql_store.set_messages(
self.session_id, messages, status=MessageStatus.ACTIVE
)
Expand All @@ -756,6 +770,26 @@ async def aget_all(
"""Get all messages."""
return await self.sql_store.get_messages(self.session_id, status=status)

async def aget_by_id(
self, message_id: str, status: Optional[MessageStatus] = None
) -> Optional[ChatMessage]:
"""
Get a specific message by its ID.

Args:
message_id: The ID of the message to retrieve
status: Filter by message status (active, archived, etc.)

Returns:
The message if found, None otherwise

"""
all_messages = await self.sql_store.get_messages(self.session_id, status=status)
for message in all_messages:
if message.id == message_id:
return message
return None

async def areset(self, status: Optional[MessageStatus] = None) -> None:
"""Reset the memory."""
await self.sql_store.delete_messages(self.session_id, status=status)
Expand All @@ -770,6 +804,22 @@ def get_all(self, status: Optional[MessageStatus] = None) -> List[ChatMessage]:
"""Get all messages."""
return asyncio_run(self.aget_all(status=status))

def get_by_id(
self, message_id: str, status: Optional[MessageStatus] = None
) -> Optional[ChatMessage]:
"""
Get a specific message by its ID.

Args:
message_id: The ID of the message to retrieve
status: Filter by message status (active, archived, etc.)

Returns:
The message if found, None otherwise

"""
return asyncio_run(self.aget_by_id(message_id, status=status))

def put(self, message: ChatMessage) -> None:
"""Add a message to the chat store and process waterfall logic if needed."""
return asyncio_run(self.aput(message))
Expand Down
30 changes: 30 additions & 0 deletions llama-index-core/llama_index/core/memory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@ async def aget_all(self) -> List[ChatMessage]:
"""Get all chat history."""
return await asyncio.to_thread(self.get_all)

def get_by_id(self, message_id: str) -> Optional[ChatMessage]:
"""
Get a specific message by its ID.

Args:
message_id: The ID of the message to retrieve

Returns:
The message if found, None otherwise

"""
all_messages = self.get_all()
for message in all_messages:
if message.id == message_id:
return message
return None

async def aget_by_id(self, message_id: str) -> Optional[ChatMessage]:
"""
Get a specific message by its ID (async).

Args:
message_id: The ID of the message to retrieve

Returns:
The message if found, None otherwise

"""
return await asyncio.to_thread(self.get_by_id, message_id)

@abstractmethod
def put(self, message: ChatMessage) -> None:
"""Put chat history."""
Expand Down
3 changes: 3 additions & 0 deletions llama-index-core/tests/base/llms/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ class SimpleModel(BaseModel):
content="test content",
additional_kwargs={"some_list": ["a", "b", "c"], "some_object": SimpleModel()},
)
temp_str = str(m.model_dump())
assert m.model_dump() == {
"role": MessageRole.USER,
"additional_kwargs": {
"some_list": ["a", "b", "c"],
"some_object": {"some_field": ""},
},
"blocks": [{"block_type": "text", "text": "test content"}],
"id": None,
}


Expand All @@ -141,6 +143,7 @@ def test_chat_message_legacy_roundtrip():
"additional_kwargs": {},
"blocks": [{"block_type": "text", "text": "foo"}],
"role": MessageRole.USER,
"id": None,
}


Expand Down
77 changes: 77 additions & 0 deletions llama-index-core/tests/memory/test_message_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Test message ID functionality in memory components."""

import pytest

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.memory.memory import Memory


def test_chat_message_id():
"""Test that ChatMessage can have an ID set."""
# Create message with ID
message_with_id = ChatMessage(
content="Test message", role=MessageRole.USER, id="test-id-123"
)
assert message_with_id.id == "test-id-123"

# Create message without ID
message_without_id = ChatMessage(content="Test message", role=MessageRole.USER)
assert message_without_id.id is None


@pytest.mark.asyncio
async def test_memory_message_id():
"""Test that Memory handles message IDs properly."""
memory = Memory.from_defaults()

# Add message with ID
message1 = ChatMessage(content="Message 1", role=MessageRole.USER, id="msg-1")
await memory.aput(message1)

# Add message without ID (should get auto-assigned)
message2 = ChatMessage(content="Message 2", role=MessageRole.ASSISTANT)
await memory.aput(message2)

# Verify message2 got an ID
messages = await memory.aget_all()
assert len(messages) == 2
assert messages[0].id == "msg-1"
assert messages[1].id is not None
assert messages[1].content == "Message 2"

# Get message by ID
retrieved_message = await memory.aget_by_id("msg-1")
assert retrieved_message is not None
assert retrieved_message.content == "Message 1"

# Try getting message with non-existent ID
non_existent = await memory.aget_by_id("non-existent-id")
assert non_existent is None


@pytest.mark.asyncio
async def test_memory_multiple_messages():
"""Test that Memory.aput_messages assigns IDs to multiple messages."""
memory = Memory.from_defaults()

# Create messages
messages = [
ChatMessage(content="Message 1", role=MessageRole.USER, id="msg-1"),
ChatMessage(content="Message 2", role=MessageRole.ASSISTANT), # No ID
ChatMessage(content="Message 3", role=MessageRole.USER), # No ID
]

# Add messages
await memory.aput_messages(messages)

# Verify messages
retrieved = await memory.aget_all()
assert len(retrieved) == 3
assert retrieved[0].id == "msg-1"
assert retrieved[1].id is not None # Auto-assigned
assert retrieved[2].id is not None # Auto-assigned

# Get message by ID
msg1 = await memory.aget_by_id("msg-1")
assert msg1 is not None
assert msg1.content == "Message 1"
Loading