diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 330a66dc10..4daef8e76e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -790,9 +790,9 @@ async def run_agent_stream( # Create session (with service session support) if config.use_service_session: supplied_thread_id = input_data.get("thread_id") or input_data.get("threadId") - session = AgentSession(service_session_id=supplied_thread_id) + session = AgentSession(session_id=thread_id, service_session_id=supplied_thread_id) else: - session = AgentSession() + session = AgentSession(session_id=thread_id) # Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation) base_metadata: dict[str, Any] = { diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index 744196dbdf..64ac8e9d66 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -183,6 +183,7 @@ def __init__( self.client = client or SimpleNamespace(function_invocation_configuration=None) self.messages_received: list[Any] = [] self.tools_received: list[Any] | None = None + self.last_session: AgentSession | None = None @overload def run( @@ -216,6 +217,7 @@ def run( async def _stream() -> AsyncIterator[AgentResponseUpdate]: self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type] + self.last_session = session self.tools_received = kwargs.get("tools") for update in self.updates: yield update diff --git a/python/packages/ag-ui/tests/ag_ui/test_run.py b/python/packages/ag-ui/tests/ag_ui/test_run.py index 392f0cd723..70af4c064a 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run.py @@ -1640,3 +1640,115 @@ def test_reasoning_distinct_ids_close_previous_block(self): # close: MsgEnd(block2) + End(block2) assert isinstance(close[0], ReasoningMessageEndEvent) assert close[0].message_id == "block2" + + +async def test_session_id_matches_thread_id(): + """Session created by run_agent_stream uses the client thread_id as session_id.""" + from conftest import StubAgent + + from agent_framework_ag_ui import AgentFrameworkAgent + + stub = StubAgent() + agent = AgentFrameworkAgent(agent=stub) + + payload = { + "thread_id": "my-thread-123", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Hello"}], + } + + _ = [event async for event in agent.run(payload)] + + assert stub.last_session is not None + assert stub.last_session.session_id == "my-thread-123" + + +async def test_session_id_matches_camel_case_thread_id(): + """Session uses threadId (camelCase) as session_id when snake_case is absent.""" + from conftest import StubAgent + + from agent_framework_ag_ui import AgentFrameworkAgent + + stub = StubAgent() + agent = AgentFrameworkAgent(agent=stub) + + payload = { + "threadId": "camel-thread-456", + "run_id": "run-2", + "messages": [{"role": "user", "content": "Hello"}], + } + + _ = [event async for event in agent.run(payload)] + + assert stub.last_session is not None + assert stub.last_session.session_id == "camel-thread-456" + + +async def test_session_id_matches_thread_id_with_service_session(): + """Session uses thread_id as session_id even when use_service_session is enabled.""" + from conftest import StubAgent + + from agent_framework_ag_ui import AgentFrameworkAgent + + stub = StubAgent() + agent = AgentFrameworkAgent(agent=stub, use_service_session=True) + + payload = { + "thread_id": "service-thread-789", + "run_id": "run-3", + "messages": [{"role": "user", "content": "Hello"}], + } + + _ = [event async for event in agent.run(payload)] + + assert stub.last_session is not None + assert stub.last_session.session_id == "service-thread-789" + assert stub.last_session.service_session_id == "service-thread-789" + + +async def test_session_id_generated_when_no_thread_id(): + """Session gets a generated UUID as session_id when no thread_id is provided.""" + import uuid + + from conftest import StubAgent + + from agent_framework_ag_ui import AgentFrameworkAgent + + stub = StubAgent() + agent = AgentFrameworkAgent(agent=stub) + + payload = { + "run_id": "run-4", + "messages": [{"role": "user", "content": "Hello"}], + } + + _ = [event async for event in agent.run(payload)] + + assert stub.last_session is not None + # Should be a valid UUID (auto-generated) + uuid.UUID(stub.last_session.session_id) + + +async def test_service_session_no_thread_id_generates_uuid(): + """With use_service_session=True and no thread_id, session_id is a UUID and service_session_id is None.""" + import uuid + + from conftest import StubAgent + + from agent_framework_ag_ui import AgentFrameworkAgent + + stub = StubAgent() + agent = AgentFrameworkAgent(agent=stub, use_service_session=True) + + payload = { + "run_id": "run-5", + "messages": [{"role": "user", "content": "Hello"}], + } + + _ = [event async for event in agent.run(payload)] + + assert stub.last_session is not None + # session_id should be a valid auto-generated UUID + uuid.UUID(stub.last_session.session_id) + # service_session_id should be None since no thread_id was supplied + assert stub.last_session.service_session_id is None diff --git a/python/samples/02-agents/conversations/file_history_provider.py b/python/samples/02-agents/conversations/file_history_provider.py index 04a87f8224..20735ffd17 100644 --- a/python/samples/02-agents/conversations/file_history_provider.py +++ b/python/samples/02-agents/conversations/file_history_provider.py @@ -21,7 +21,7 @@ from pydantic import Field try: - import orjson + import orjson # pyright: ignore[reportMissingImports] except ImportError: orjson = None diff --git a/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py b/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py index 70c5d7e8e8..693501b0f9 100644 --- a/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py +++ b/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py @@ -22,7 +22,7 @@ from pydantic import Field try: - import orjson + import orjson # pyright: ignore[reportMissingImports] except ImportError: orjson = None