Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ async def run_agent_stream(
# Create session (with service session support)
if config.use_service_session:
supplied_thread_id = input_data.get("thread_id") or input_data.get("threadId")
session = AgentSession(service_session_id=supplied_thread_id)
session = AgentSession(session_id=thread_id, service_session_id=supplied_thread_id)
else:
session = AgentSession()
session = AgentSession(session_id=thread_id)

# Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation)
base_metadata: dict[str, Any] = {
Expand Down
2 changes: 2 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
self.client = client or SimpleNamespace(function_invocation_configuration=None)
self.messages_received: list[Any] = []
self.tools_received: list[Any] | None = None
self.last_session: AgentSession | None = None

@overload
def run(
Expand Down Expand Up @@ -216,6 +217,7 @@ def run(

async def _stream() -> AsyncIterator[AgentResponseUpdate]:
self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type]
self.last_session = session
self.tools_received = kwargs.get("tools")
for update in self.updates:
yield update
Expand Down
112 changes: 112 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,3 +1640,115 @@ def test_reasoning_distinct_ids_close_previous_block(self):
# close: MsgEnd(block2) + End(block2)
assert isinstance(close[0], ReasoningMessageEndEvent)
assert close[0].message_id == "block2"


async def test_session_id_matches_thread_id():
"""Session created by run_agent_stream uses the client thread_id as session_id."""
from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub)

payload = {
"thread_id": "my-thread-123",
"run_id": "run-1",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
assert stub.last_session.session_id == "my-thread-123"


async def test_session_id_matches_camel_case_thread_id():
"""Session uses threadId (camelCase) as session_id when snake_case is absent."""
from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub)

payload = {
"threadId": "camel-thread-456",
"run_id": "run-2",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
assert stub.last_session.session_id == "camel-thread-456"


async def test_session_id_matches_thread_id_with_service_session():
"""Session uses thread_id as session_id even when use_service_session is enabled."""
from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub, use_service_session=True)

payload = {
"thread_id": "service-thread-789",
"run_id": "run-3",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
assert stub.last_session.session_id == "service-thread-789"
assert stub.last_session.service_session_id == "service-thread-789"


async def test_session_id_generated_when_no_thread_id():
"""Session gets a generated UUID as session_id when no thread_id is provided."""
import uuid

from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub)

payload = {
"run_id": "run-4",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
# Should be a valid UUID (auto-generated)
uuid.UUID(stub.last_session.session_id)


async def test_service_session_no_thread_id_generates_uuid():
"""With use_service_session=True and no thread_id, session_id is a UUID and service_session_id is None."""
import uuid

from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

stub = StubAgent()
agent = AgentFrameworkAgent(agent=stub, use_service_session=True)

payload = {
"run_id": "run-5",
"messages": [{"role": "user", "content": "Hello"}],
}

_ = [event async for event in agent.run(payload)]

assert stub.last_session is not None
# session_id should be a valid auto-generated UUID
uuid.UUID(stub.last_session.session_id)
# service_session_id should be None since no thread_id was supplied
assert stub.last_session.service_session_id is None
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import Field

try:
import orjson
import orjson # pyright: ignore[reportMissingImports]
except ImportError:
orjson = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import Field

try:
import orjson
import orjson # pyright: ignore[reportMissingImports]
except ImportError:
orjson = None

Expand Down
Loading