-
Couldn't load subscription status.
- Fork 612
Python: Add Cosmos DB ChatMessageStore sample example for external chat history #1533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aman-panjwani
wants to merge
5
commits into
microsoft:main
Choose a base branch
from
aman-panjwani:feature/cosmosdb-chat-store
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
aab6bc8
Add Cosmos DB ChatMessageStore example for external chat history
16bb4bb
Minor Updates
3f325c8
Merge branch 'microsoft:main' into feature/cosmosdb-chat-store
aman-panjwani e9645ad
Editorial updates and cleanup
daa5a07
Merge branch 'feature/cosmosdb-chat-store' of https://github.com/aman…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
237 changes: 237 additions & 0 deletions
237
python/samples/getting_started/chat_store/third_party_chat_store_cosmosDB.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,237 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| """Cosmos DB Chat Message Store Example | ||
|
|
||
| Demonstrates how to store and retrieve chat history using Azure Cosmos DB | ||
| as an external message store for the Microsoft Agent Framework. | ||
|
|
||
| Scenarios: | ||
| 1) Persist chat messages in Cosmos DB with thread-based partitioning. | ||
| 2) Retrieve messages in chronological order for conversation continuity. | ||
| 3) Serialize and deserialize thread state for persistence across sessions. | ||
| 4) Properly close async Cosmos DB and chat client sessions. | ||
|
|
||
| Requirements: | ||
| - Azure Cosmos DB (Core SQL API) with an existing database and container. | ||
| - Container partition key must be /thread_id. | ||
| - Environment variables: | ||
| COSMOS_DB_ENDPOINT, COSMOS_DB_KEY, | ||
| AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, | ||
| AZURE_OPENAI_API_VERSION, AZURE_OPENAI_CHAT_DEPLOYMENT_NAME | ||
| - Dependencies: | ||
| pip install azure-cosmos pydantic "agent-framework" | ||
aman-panjwani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Run: | ||
aman-panjwani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| python third_party_chat_store_cosmosDB.py | ||
| """ | ||
|
|
||
| import os | ||
aman-panjwani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from time import time | ||
| from uuid import uuid4 | ||
| from typing import Sequence, Any | ||
| import enum, datetime, uuid | ||
|
|
||
| from pydantic import BaseModel | ||
| from azure.cosmos.aio import CosmosClient | ||
|
|
||
| from agent_framework import ChatMessage, ChatAgent | ||
| from agent_framework.azure import AzureOpenAIChatClient | ||
|
|
||
|
|
||
| class CosmosDBStoreState(BaseModel): | ||
| """Serializable state for CosmosDB chat message store.""" | ||
| thread_id: str | ||
| cosmos_endpoint: str | None = None | ||
| database_name: str = "agent_framework" | ||
| container_name: str = "chat_messages" | ||
|
|
||
|
|
||
| class CosmosDBChatMessageStore: | ||
| """ | ||
| Lightweight Cosmos DB-backed chat history store for Microsoft Agent Framework. | ||
|
|
||
| This implementation: | ||
| - Uses key-based authentication. | ||
| - Stores one conversation thread per partition (partition key = /thread_id). | ||
| - Appends messages and retrieves them chronologically. | ||
| - Supports serialization/deserialization for persistent AgentThread state. | ||
| - Assumes database and container already exist. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| cosmos_endpoint: str, | ||
aman-panjwani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| cosmos_key: str, | ||
| *, | ||
| thread_id: str | None = None, | ||
| database_name: str = "agent_framework", | ||
| container_name: str = "chat_messages", | ||
| ) -> None: | ||
| if not cosmos_endpoint: | ||
| raise ValueError("cosmos_endpoint is required") | ||
| if not cosmos_key: | ||
| raise ValueError("cosmos_key is required") | ||
|
|
||
| self.cosmos_endpoint = cosmos_endpoint | ||
| self.cosmos_key = cosmos_key | ||
| self.thread_id = thread_id or f"thread_{uuid4()}" | ||
| self.database_name = database_name | ||
| self.container_name = container_name | ||
|
|
||
| self._client: CosmosClient | None = None | ||
| self._container = None | ||
| self._ready = False | ||
|
|
||
| async def _ensure(self) -> None: | ||
| """Initialize the Cosmos DB client and container reference if not already set.""" | ||
| if self._ready: | ||
| return | ||
| self._client = CosmosClient(self.cosmos_endpoint, self.cosmos_key) | ||
| db = self._client.get_database_client(self.database_name) | ||
| self._container = db.get_container_client(self.container_name) | ||
| self._ready = True | ||
|
|
||
| async def add_messages(self, messages: Sequence[ChatMessage]) -> None: | ||
| """Persist new chat messages for the current thread.""" | ||
| if not messages: | ||
| return | ||
| await self._ensure() | ||
| for msg in messages: | ||
| doc = { | ||
| "id": f"{self.thread_id}_{uuid4()}", | ||
| "thread_id": self.thread_id, | ||
| "ts": time(), | ||
| "message": self._to_dict(msg), | ||
| } | ||
| await self._container.upsert_item(doc) | ||
|
|
||
| async def list_messages(self) -> list[ChatMessage]: | ||
| """Retrieve all chat messages for the thread in chronological order.""" | ||
| await self._ensure() | ||
| query = """ | ||
| SELECT c.message FROM c | ||
| WHERE c.thread_id = @thread_id | ||
| ORDER BY c.ts ASC | ||
| """ | ||
| params = [{"name": "@thread_id", "value": self.thread_id}] | ||
| messages: list[ChatMessage] = [] | ||
| async for row in self._container.query_items( | ||
| query=query, parameters=params, partition_key=self.thread_id | ||
| ): | ||
| messages.append(self._from_dict(row["message"])) | ||
| return messages | ||
|
|
||
| async def clear(self) -> None: | ||
| """Delete all chat messages for the current thread.""" | ||
| await self._ensure() | ||
| query = "SELECT c.id FROM c WHERE c.thread_id = @thread_id" | ||
| params = [{"name": "@thread_id", "value": self.thread_id}] | ||
| async for row in self._container.query_items( | ||
| query=query, parameters=params, partition_key=self.thread_id | ||
| ): | ||
| await self._container.delete_item(row["id"], partition_key=self.thread_id) | ||
|
|
||
| async def aclose(self) -> None: | ||
| """Close the Cosmos DB client connection.""" | ||
| if self._client: | ||
| await self._client.close() | ||
| self._ready = False | ||
|
|
||
| async def serialize_state(self, **kwargs: Any) -> dict: | ||
| """Serialize the store configuration for persistent thread state.""" | ||
| return CosmosDBStoreState( | ||
| thread_id=self.thread_id, | ||
| cosmos_endpoint=self.cosmos_endpoint, | ||
| database_name=self.database_name, | ||
| container_name=self.container_name, | ||
| ).model_dump(**kwargs) | ||
|
|
||
| async def deserialize_state(self, state: dict | None, **_: Any) -> None: | ||
| """Restore store configuration from serialized thread state.""" | ||
| if not state: | ||
| return | ||
| s = CosmosDBStoreState.model_validate(state) | ||
| self.thread_id = s.thread_id | ||
| if (s.cosmos_endpoint and s.cosmos_endpoint != self.cosmos_endpoint) or \ | ||
| s.database_name != self.database_name or \ | ||
| s.container_name != self.container_name: | ||
| self.cosmos_endpoint = s.cosmos_endpoint or self.cosmos_endpoint | ||
| self.database_name = s.database_name | ||
| self.container_name = s.container_name | ||
| self._ready = False | ||
|
|
||
| def _to_dict(self, message: ChatMessage) -> dict: | ||
| """Convert ChatMessage into a JSON-safe dictionary.""" | ||
| def make_safe(obj): | ||
| if obj is None: | ||
| return None | ||
| if isinstance(obj, (str, int, float, bool)): | ||
| return obj | ||
| if isinstance(obj, enum.Enum): | ||
| return obj.value | ||
| if isinstance(obj, (datetime.datetime, datetime.date)): | ||
| return obj.isoformat() | ||
| if isinstance(obj, uuid.UUID): | ||
| return str(obj) | ||
| if isinstance(obj, list): | ||
| return [make_safe(x) for x in obj] | ||
| if isinstance(obj, dict): | ||
| return {k: make_safe(v) for k, v in obj.items()} | ||
| if hasattr(obj, "dict"): | ||
| return make_safe(obj.dict()) | ||
| if hasattr(obj, "model_dump"): | ||
| return make_safe(obj.model_dump()) | ||
| if hasattr(obj, "__dict__"): | ||
| return make_safe(vars(obj)) | ||
| return str(obj) | ||
|
|
||
| if hasattr(message, "model_dump"): | ||
aman-panjwani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| raw = message.model_dump() | ||
| elif hasattr(message, "dict"): | ||
| raw = message.dict() | ||
| else: | ||
| raw = vars(message) | ||
| return make_safe(raw) | ||
|
|
||
| def _from_dict(self, data: dict) -> ChatMessage: | ||
| """Reconstruct a ChatMessage from a stored dictionary.""" | ||
| if hasattr(ChatMessage, "model_validate"): | ||
aman-panjwani marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return ChatMessage.model_validate(data) | ||
| if hasattr(ChatMessage, "parse_obj"): | ||
| return ChatMessage.parse_obj(data) | ||
| return ChatMessage(**data) | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| """Demonstration of CosmosDBChatMessageStore with ChatAgent.""" | ||
| store = CosmosDBChatMessageStore( | ||
| cosmos_endpoint=os.getenv("COSMOS_DB_ENDPOINT"), | ||
| cosmos_key=os.getenv("COSMOS_DB_KEY"), | ||
| database_name="agent-chat-conversation", | ||
| container_name="chat_messages", | ||
| ) | ||
|
|
||
| chat_client = AzureOpenAIChatClient( | ||
| model_id=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"), | ||
| api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||
| endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||
| api_version=os.getenv("AZURE_OPENAI_API_VERSION"), | ||
| ) | ||
|
|
||
| agent = ChatAgent( | ||
| chat_client=chat_client, | ||
| name="Joker", | ||
| instructions="You are good at telling jokes.", | ||
| chat_message_store_factory=lambda: store, | ||
| ) | ||
|
|
||
| try: | ||
| thread = agent.get_new_thread() | ||
| await agent.run("Tell me a pirate joke.", thread=thread) | ||
| await agent.run("One more!", thread=thread) | ||
| finally: | ||
| await store.aclose() | ||
|
|
||
| if __name__ == "__main__": | ||
| import asyncio | ||
| asyncio.run(main()) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.