Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ async def _handle_mcp_tool_invocation(
# Create or parse session ID
if thread_id and isinstance(thread_id, str) and thread_id.strip():
try:
session_id = AgentSessionId.parse(thread_id)
session_id = AgentSessionId.parse(thread_id, agent_name=agent_name)
except ValueError as e:
logger.warning(
"Failed to parse AgentSessionId from thread_id '%s': %s. Falling back to new session ID.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,34 @@ def __repr__(self) -> str:
return f"AgentSessionId(name='{self.name}', key='{self.key}')"

@staticmethod
def parse(session_id_string: str) -> AgentSessionId:
def parse(session_id_string: str, agent_name: str | None = None) -> AgentSessionId:
"""Parses a string representation of an agent session ID.

Args:
session_id_string: A string in the form @name@key
session_id_string: A string in the form @name@key, or a plain key string
when agent_name is provided.
agent_name: Optional agent name to use instead of parsing from the string.
If provided, only the key portion is extracted from session_id_string
(for @name@key format) or the entire string is used as the key
(for plain strings).

Returns:
AgentSessionId instance

Raises:
ValueError: If the string format is invalid
ValueError: If the string format is invalid and agent_name is not provided
"""
if not session_id_string.startswith("@"):
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
# Check if string is in @name@key format
if session_id_string.startswith("@") and "@" in session_id_string[1:]:
parts = session_id_string[1:].split("@", 1)
name = agent_name if agent_name is not None else parts[0]
return AgentSessionId(name=name, key=parts[1])

parts = session_id_string[1:].split("@", 1)
if len(parts) != 2:
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
# Plain string format - only valid when agent_name is provided
if agent_name is not None:
return AgentSessionId(name=agent_name, key=session_id_string)

return AgentSessionId(name=parts[0], key=parts[1])
raise ValueError(f"Invalid agent session ID format: {session_id_string}")


class DurableAgentThread(AgentThread):
Expand Down
64 changes: 64 additions & 0 deletions python/packages/azurefunctions/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,70 @@ async def test_handle_mcp_tool_invocation_runtime_error(self) -> None:
with pytest.raises(RuntimeError, match="Agent execution failed"):
await app._handle_mcp_tool_invocation("TestAgent", context, client)

async def test_handle_mcp_tool_invocation_ignores_agent_name_in_thread_id(self) -> None:
"""Test that MCP tool invocation uses the agent_name parameter, not the name from thread_id."""
mock_agent = Mock()
mock_agent.name = "PlantAdvisor"

app = AgentFunctionApp(agents=[mock_agent])
client = AsyncMock()

# Mock the entity response
mock_state = Mock()
mock_state.entity_state = {
"schemaVersion": "1.0.0",
"data": {"conversationHistory": []},
}
client.read_entity_state.return_value = mock_state

# Thread ID contains a different agent name (@StockAdvisor@poc123)
# but we're invoking PlantAdvisor - it should use PlantAdvisor's entity
context = json.dumps({"arguments": {"query": "test query", "threadId": "@StockAdvisor@test123"}})

with patch.object(app, "_get_response_from_entity") as get_response_mock:
get_response_mock.return_value = {"status": "success", "response": "Test response"}

await app._handle_mcp_tool_invocation("PlantAdvisor", context, client)

# Verify signal_entity was called with PlantAdvisor's entity, not StockAdvisor's
client.signal_entity.assert_called_once()
call_args = client.signal_entity.call_args
entity_id = call_args[0][0]

# Entity name should be dafx-PlantAdvisor, not dafx-StockAdvisor
assert entity_id.name == "dafx-PlantAdvisor"
assert entity_id.key == "test123"

async def test_handle_mcp_tool_invocation_uses_plain_thread_id_as_key(self) -> None:
"""Test that a plain thread_id (not in @name@key format) is used as-is for the key."""
mock_agent = Mock()
mock_agent.name = "TestAgent"

app = AgentFunctionApp(agents=[mock_agent])
client = AsyncMock()

mock_state = Mock()
mock_state.entity_state = {
"schemaVersion": "1.0.0",
"data": {"conversationHistory": []},
}
client.read_entity_state.return_value = mock_state

# Plain thread_id without @name@key format
context = json.dumps({"arguments": {"query": "test query", "threadId": "simple-thread-123"}})

with patch.object(app, "_get_response_from_entity") as get_response_mock:
get_response_mock.return_value = {"status": "success", "response": "Test response"}

await app._handle_mcp_tool_invocation("TestAgent", context, client)

client.signal_entity.assert_called_once()
call_args = client.signal_entity.call_args
entity_id = call_args[0][0]

assert entity_id.name == "dafx-TestAgent"
assert entity_id.key == "simple-thread-123"

def test_health_check_includes_mcp_tool_enabled(self) -> None:
"""Test that health check endpoint includes mcp_tool_enabled field."""
mock_agent = Mock()
Expand Down
28 changes: 28 additions & 0 deletions python/packages/azurefunctions/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,34 @@ def test_parse_round_trip(self) -> None:
assert parsed.name == original.name
assert parsed.key == original.key

def test_parse_with_agent_name_override(self) -> None:
"""Test parsing @name@key format with agent_name parameter overrides the name."""
session_id = AgentSessionId.parse("@OriginalAgent@test-key-123", agent_name="OverriddenAgent")

assert session_id.name == "OverriddenAgent"
assert session_id.key == "test-key-123"

def test_parse_without_agent_name_uses_parsed_name(self) -> None:
"""Test parsing @name@key format without agent_name uses name from string."""
session_id = AgentSessionId.parse("@ParsedAgent@test-key-123")

assert session_id.name == "ParsedAgent"
assert session_id.key == "test-key-123"

def test_parse_plain_string_with_agent_name(self) -> None:
"""Test parsing plain string with agent_name uses entire string as key."""
session_id = AgentSessionId.parse("simple-thread-123", agent_name="TestAgent")

assert session_id.name == "TestAgent"
assert session_id.key == "simple-thread-123"

def test_parse_plain_string_without_agent_name_raises(self) -> None:
"""Test parsing plain string without agent_name raises ValueError."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("simple-thread-123")

assert "Invalid agent session ID format" in str(exc_info.value)

def test_to_entity_name_adds_prefix(self) -> None:
"""Test that to_entity_name adds the dafx- prefix."""
entity_name = AgentSessionId.to_entity_name("TestAgent")
Expand Down
Loading