diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 29e4a7df6a..5e64a3feaf 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -609,7 +609,7 @@ async def _handle_mcp_tool_invocation( # Create or parse session ID if thread_id and isinstance(thread_id, str) and thread_id.strip(): try: - session_id = AgentSessionId.parse(thread_id) + session_id = AgentSessionId.parse(thread_id, agent_name=agent_name) except ValueError as e: logger.warning( "Failed to parse AgentSessionId from thread_id '%s': %s. Falling back to new session ID.", diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index 2ab9667575..ffee3b77fe 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -109,26 +109,34 @@ def __repr__(self) -> str: return f"AgentSessionId(name='{self.name}', key='{self.key}')" @staticmethod - def parse(session_id_string: str) -> AgentSessionId: + def parse(session_id_string: str, agent_name: str | None = None) -> AgentSessionId: """Parses a string representation of an agent session ID. Args: - session_id_string: A string in the form @name@key + session_id_string: A string in the form @name@key, or a plain key string + when agent_name is provided. + agent_name: Optional agent name to use instead of parsing from the string. + If provided, only the key portion is extracted from session_id_string + (for @name@key format) or the entire string is used as the key + (for plain strings). Returns: AgentSessionId instance Raises: - ValueError: If the string format is invalid + ValueError: If the string format is invalid and agent_name is not provided """ - if not session_id_string.startswith("@"): - raise ValueError(f"Invalid agent session ID format: {session_id_string}") + # Check if string is in @name@key format + if session_id_string.startswith("@") and "@" in session_id_string[1:]: + parts = session_id_string[1:].split("@", 1) + name = agent_name if agent_name is not None else parts[0] + return AgentSessionId(name=name, key=parts[1]) - parts = session_id_string[1:].split("@", 1) - if len(parts) != 2: - raise ValueError(f"Invalid agent session ID format: {session_id_string}") + # Plain string format - only valid when agent_name is provided + if agent_name is not None: + return AgentSessionId(name=agent_name, key=session_id_string) - return AgentSessionId(name=parts[0], key=parts[1]) + raise ValueError(f"Invalid agent session ID format: {session_id_string}") class DurableAgentThread(AgentThread): diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 29d614e729..b4b0428f43 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1056,6 +1056,70 @@ async def test_handle_mcp_tool_invocation_runtime_error(self) -> None: with pytest.raises(RuntimeError, match="Agent execution failed"): await app._handle_mcp_tool_invocation("TestAgent", context, client) + async def test_handle_mcp_tool_invocation_ignores_agent_name_in_thread_id(self) -> None: + """Test that MCP tool invocation uses the agent_name parameter, not the name from thread_id.""" + mock_agent = Mock() + mock_agent.name = "PlantAdvisor" + + app = AgentFunctionApp(agents=[mock_agent]) + client = AsyncMock() + + # Mock the entity response + mock_state = Mock() + mock_state.entity_state = { + "schemaVersion": "1.0.0", + "data": {"conversationHistory": []}, + } + client.read_entity_state.return_value = mock_state + + # Thread ID contains a different agent name (@StockAdvisor@poc123) + # but we're invoking PlantAdvisor - it should use PlantAdvisor's entity + context = json.dumps({"arguments": {"query": "test query", "threadId": "@StockAdvisor@test123"}}) + + with patch.object(app, "_get_response_from_entity") as get_response_mock: + get_response_mock.return_value = {"status": "success", "response": "Test response"} + + await app._handle_mcp_tool_invocation("PlantAdvisor", context, client) + + # Verify signal_entity was called with PlantAdvisor's entity, not StockAdvisor's + client.signal_entity.assert_called_once() + call_args = client.signal_entity.call_args + entity_id = call_args[0][0] + + # Entity name should be dafx-PlantAdvisor, not dafx-StockAdvisor + assert entity_id.name == "dafx-PlantAdvisor" + assert entity_id.key == "test123" + + async def test_handle_mcp_tool_invocation_uses_plain_thread_id_as_key(self) -> None: + """Test that a plain thread_id (not in @name@key format) is used as-is for the key.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent]) + client = AsyncMock() + + mock_state = Mock() + mock_state.entity_state = { + "schemaVersion": "1.0.0", + "data": {"conversationHistory": []}, + } + client.read_entity_state.return_value = mock_state + + # Plain thread_id without @name@key format + context = json.dumps({"arguments": {"query": "test query", "threadId": "simple-thread-123"}}) + + with patch.object(app, "_get_response_from_entity") as get_response_mock: + get_response_mock.return_value = {"status": "success", "response": "Test response"} + + await app._handle_mcp_tool_invocation("TestAgent", context, client) + + client.signal_entity.assert_called_once() + call_args = client.signal_entity.call_args + entity_id = call_args[0][0] + + assert entity_id.name == "dafx-TestAgent" + assert entity_id.key == "simple-thread-123" + def test_health_check_includes_mcp_tool_enabled(self) -> None: """Test that health check endpoint includes mcp_tool_enabled field.""" mock_agent = Mock() diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py index 74efa9c166..be31f59800 100644 --- a/python/packages/azurefunctions/tests/test_models.py +++ b/python/packages/azurefunctions/tests/test_models.py @@ -120,6 +120,34 @@ def test_parse_round_trip(self) -> None: assert parsed.name == original.name assert parsed.key == original.key + def test_parse_with_agent_name_override(self) -> None: + """Test parsing @name@key format with agent_name parameter overrides the name.""" + session_id = AgentSessionId.parse("@OriginalAgent@test-key-123", agent_name="OverriddenAgent") + + assert session_id.name == "OverriddenAgent" + assert session_id.key == "test-key-123" + + def test_parse_without_agent_name_uses_parsed_name(self) -> None: + """Test parsing @name@key format without agent_name uses name from string.""" + session_id = AgentSessionId.parse("@ParsedAgent@test-key-123") + + assert session_id.name == "ParsedAgent" + assert session_id.key == "test-key-123" + + def test_parse_plain_string_with_agent_name(self) -> None: + """Test parsing plain string with agent_name uses entire string as key.""" + session_id = AgentSessionId.parse("simple-thread-123", agent_name="TestAgent") + + assert session_id.name == "TestAgent" + assert session_id.key == "simple-thread-123" + + def test_parse_plain_string_without_agent_name_raises(self) -> None: + """Test parsing plain string without agent_name raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("simple-thread-123") + + assert "Invalid agent session ID format" in str(exc_info.value) + def test_to_entity_name_adds_prefix(self) -> None: """Test that to_entity_name adds the dafx- prefix.""" entity_name = AgentSessionId.to_entity_name("TestAgent")