diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index 48d053e60c..cffc8b9d8d 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -294,6 +294,9 @@ class ChatMessage(BaseModel):
     role: MessageRole = MessageRole.USER
     additional_kwargs: dict[str, Any] = Field(default_factory=dict)
     blocks: list[ContentBlock] = Field(default_factory=list)
+    id: Optional[str] = Field(
+        default=None, description="Optional unique identifier for the message"
+    )
 
     def __init__(self, /, content: Any | None = None, **data: Any) -> None:
         """
diff --git a/llama-index-core/llama_index/core/memory/memory.py b/llama-index-core/llama_index/core/memory/memory.py
index 57d9358838..17cad9dbac 100644
--- a/llama-index-core/llama_index/core/memory/memory.py
+++ b/llama-index-core/llama_index/core/memory/memory.py
@@ -726,6 +726,10 @@ async def _manage_queue(self) -> None:
 
     async def aput(self, message: ChatMessage) -> None:
         """Add a message to the chat store and process waterfall logic if needed."""
+        # Ensure message has an ID if not provided
+        if message.id is None:
+            message.id = str(uuid.uuid4())
+
         # Add the message to the chat store
         await self.sql_store.add_message(
             self.session_id, message, status=MessageStatus.ACTIVE
@@ -736,6 +740,11 @@ async def aput(self, message: ChatMessage) -> None:
 
     async def aput_messages(self, messages: List[ChatMessage]) -> None:
         """Add a list of messages to the chat store and process waterfall logic if needed."""
+        # Ensure all messages have IDs if not provided
+        for message in messages:
+            if message.id is None:
+                message.id = str(uuid.uuid4())
+
         # Add the messages to the chat store
         await self.sql_store.add_messages(
             self.session_id, messages, status=MessageStatus.ACTIVE
@@ -746,6 +755,11 @@ async def aput_messages(self, messages: List[ChatMessage]) -> None:
 
     async def aset(self, messages: List[ChatMessage]) -> None:
         """Set the chat history."""
+        # Ensure all messages have IDs if not provided
+        for message in messages:
+            if message.id is None:
+                message.id = str(uuid.uuid4())
+
         await self.sql_store.set_messages(
             self.session_id, messages, status=MessageStatus.ACTIVE
         )
@@ -756,6 +770,26 @@ async def aget_all(
         """Get all messages."""
         return await self.sql_store.get_messages(self.session_id, status=status)
 
+    async def aget_by_id(
+        self, message_id: str, status: Optional[MessageStatus] = None
+    ) -> Optional[ChatMessage]:
+        """
+        Get a specific message by its ID.
+
+        Args:
+            message_id: The ID of the message to retrieve
+            status: Filter by message status (active, archived, etc.)
+
+        Returns:
+            The message if found, None otherwise
+
+        """
+        all_messages = await self.sql_store.get_messages(self.session_id, status=status)
+        for message in all_messages:
+            if message.id == message_id:
+                return message
+        return None
+
     async def areset(self, status: Optional[MessageStatus] = None) -> None:
         """Reset the memory."""
         await self.sql_store.delete_messages(self.session_id, status=status)
@@ -770,6 +804,22 @@ def get_all(self, status: Optional[MessageStatus] = None) -> List[ChatMessage]:
         """Get all messages."""
         return asyncio_run(self.aget_all(status=status))
 
+    def get_by_id(
+        self, message_id: str, status: Optional[MessageStatus] = None
+    ) -> Optional[ChatMessage]:
+        """
+        Get a specific message by its ID.
+
+        Args:
+            message_id: The ID of the message to retrieve
+            status: Filter by message status (active, archived, etc.)
+
+        Returns:
+            The message if found, None otherwise
+
+        """
+        return asyncio_run(self.aget_by_id(message_id, status=status))
+
     def put(self, message: ChatMessage) -> None:
         """Add a message to the chat store and process waterfall logic if needed."""
         return asyncio_run(self.aput(message))
diff --git a/llama-index-core/llama_index/core/memory/types.py b/llama-index-core/llama_index/core/memory/types.py
index 8a94739b23..bc73136a82 100644
--- a/llama-index-core/llama_index/core/memory/types.py
+++ b/llama-index-core/llama_index/core/memory/types.py
@@ -45,6 +45,36 @@ async def aget_all(self) -> List[ChatMessage]:
         """Get all chat history."""
         return await asyncio.to_thread(self.get_all)
 
+    def get_by_id(self, message_id: str) -> Optional[ChatMessage]:
+        """
+        Get a specific message by its ID.
+
+        Args:
+            message_id: The ID of the message to retrieve
+
+        Returns:
+            The message if found, None otherwise
+
+        """
+        all_messages = self.get_all()
+        for message in all_messages:
+            if message.id == message_id:
+                return message
+        return None
+
+    async def aget_by_id(self, message_id: str) -> Optional[ChatMessage]:
+        """
+        Get a specific message by its ID (async).
+
+        Args:
+            message_id: The ID of the message to retrieve
+
+        Returns:
+            The message if found, None otherwise
+
+        """
+        return await asyncio.to_thread(self.get_by_id, message_id)
+
     @abstractmethod
     def put(self, message: ChatMessage) -> None:
         """Put chat history."""
diff --git a/llama-index-core/tests/base/llms/test_types.py b/llama-index-core/tests/base/llms/test_types.py
index 93ed7da20d..f901ec39e5 100644
--- a/llama-index-core/tests/base/llms/test_types.py
+++ b/llama-index-core/tests/base/llms/test_types.py
@@ -127,6 +127,7 @@ class SimpleModel(BaseModel):
             "some_object": {"some_field": ""},
         },
         "blocks": [{"block_type": "text", "text": "test content"}],
+        "id": None,
     }
 
 
@@ -141,6 +142,7 @@ def test_chat_message_legacy_roundtrip():
         "additional_kwargs": {},
         "blocks": [{"block_type": "text", "text": "foo"}],
         "role": MessageRole.USER,
+        "id": None,
     }
 
 
diff --git a/llama-index-core/tests/memory/test_message_id.py b/llama-index-core/tests/memory/test_message_id.py
new file mode 100644
index 0000000000..7d9999dd2f
--- /dev/null
+++ b/llama-index-core/tests/memory/test_message_id.py
@@ -0,0 +1,77 @@
+"""Test message ID functionality in memory components."""
+
+import pytest
+
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
+from llama_index.core.memory.memory import Memory
+
+
+def test_chat_message_id():
+    """Test that ChatMessage can have an ID set."""
+    # Create message with ID
+    message_with_id = ChatMessage(
+        content="Test message", role=MessageRole.USER, id="test-id-123"
+    )
+    assert message_with_id.id == "test-id-123"
+
+    # Create message without ID
+    message_without_id = ChatMessage(content="Test message", role=MessageRole.USER)
+    assert message_without_id.id is None
+
+
+@pytest.mark.asyncio
+async def test_memory_message_id():
+    """Test that Memory handles message IDs properly."""
+    memory = Memory.from_defaults()
+
+    # Add message with ID
+    message1 = ChatMessage(content="Message 1", role=MessageRole.USER, id="msg-1")
+    await memory.aput(message1)
+
+    # Add message without ID (should get auto-assigned)
+    message2 = ChatMessage(content="Message 2", role=MessageRole.ASSISTANT)
+    await memory.aput(message2)
+
+    # Verify message2 got an ID
+    messages = await memory.aget_all()
+    assert len(messages) == 2
+    assert messages[0].id == "msg-1"
+    assert messages[1].id is not None
+    assert messages[1].content == "Message 2"
+
+    # Get message by ID
+    retrieved_message = await memory.aget_by_id("msg-1")
+    assert retrieved_message is not None
+    assert retrieved_message.content == "Message 1"
+
+    # Try getting message with non-existent ID
+    non_existent = await memory.aget_by_id("non-existent-id")
+    assert non_existent is None
+
+
+@pytest.mark.asyncio
+async def test_memory_multiple_messages():
+    """Test that Memory.aput_messages assigns IDs to multiple messages."""
+    memory = Memory.from_defaults()
+
+    # Create messages
+    messages = [
+        ChatMessage(content="Message 1", role=MessageRole.USER, id="msg-1"),
+        ChatMessage(content="Message 2", role=MessageRole.ASSISTANT),  # No ID
+        ChatMessage(content="Message 3", role=MessageRole.USER),  # No ID
+    ]
+
+    # Add messages
+    await memory.aput_messages(messages)
+
+    # Verify messages
+    retrieved = await memory.aget_all()
+    assert len(retrieved) == 3
+    assert retrieved[0].id == "msg-1"
+    assert retrieved[1].id is not None  # Auto-assigned
+    assert retrieved[2].id is not None  # Auto-assigned
+
+    # Get message by ID
+    msg1 = await memory.aget_by_id("msg-1")
+    assert msg1 is not None
+    assert msg1.content == "Message 1"