diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 654498e371..6bdff552b6 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -86,7 +86,7 @@ def last_message(self): def run_id(self) -> str: """Get or generate run ID.""" if self._run_id is None: - self._run_id = self.input_data.get("run_id") or str(uuid.uuid4()) + self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4()) # This should never be None after the if block above, but satisfy type checkers if self._run_id is None: # pragma: no cover raise RuntimeError("Failed to initialize run_id") @@ -96,7 +96,7 @@ def run_id(self) -> str: def thread_id(self) -> str: """Get or generate thread ID.""" if self._thread_id is None: - self._thread_id = self.input_data.get("thread_id") or str(uuid.uuid4()) + self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4()) # This should never be None after the if block above, but satisfy type checkers if self._thread_id is None: # pragma: no cover raise RuntimeError("Failed to initialize thread_id") diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 10843a259c..af90ea2e88 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -83,3 +83,71 @@ async def test_default_orchestrator_merges_client_tools() -> None: assert "server_tool" in tool_names assert "get_weather" in tool_names assert agent.chat_client.function_invocation_configuration.additional_tools + + +async def test_default_orchestrator_with_camel_case_ids() -> None: + """Client tool is able to extract camelCase IDs.""" + + agent = DummyAgent() + orchestrator = DefaultOrchestrator() + + input_data = { + "runId": "test-camelcase-runid", + "threadId": "test-camelcase-threadid", + "messages": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + ], + "tools": [], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig(), + ) + + events = [] + async for event in orchestrator.run(context): + events.append(event) + + # assert the last event has the expected run_id and thread_id + last_event = events[-1] + assert last_event.run_id == "test-camelcase-runid" + assert last_event.thread_id == "test-camelcase-threadid" + + +async def test_default_orchestrator_with_snake_case_ids() -> None: + """Client tool is able to extract snake_case IDs.""" + + agent = DummyAgent() + orchestrator = DefaultOrchestrator() + + input_data = { + "run_id": "test-snakecase-runid", + "thread_id": "test-snakecase-threadid", + "messages": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + ], + "tools": [], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig(), + ) + + events = [] + async for event in orchestrator.run(context): + events.append(event) + + # assert the last event has the expected run_id and thread_id + last_event = events[-1] + assert last_event.run_id == "test-snakecase-runid" + assert last_event.thread_id == "test-snakecase-threadid"