diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py
index 8e0a7aec1e..49de54f5d2 100644
--- a/python/packages/core/agent_framework/_workflows/_handoff.py
+++ b/python/packages/core/agent_framework/_workflows/_handoff.py
@@ -130,19 +130,57 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent:
 @dataclass
 class HandoffUserInputRequest:
-    """Request message emitted when the workflow needs fresh user input."""
+    """Request message emitted when the workflow needs fresh user input.
+
+    Note: The conversation field is intentionally excluded from checkpoint serialization
+    to prevent duplication. The conversation is preserved in the coordinator's state
+    and will be reconstructed on restore. See issue #2667.
+    """
 
     conversation: list[ChatMessage]
     awaiting_agent_id: str
     prompt: str
    source_executor_id: str
 
+    def to_dict(self) -> dict[str, Any]:
+        """Serialize to dict, excluding conversation to prevent checkpoint duplication.
+
+        The conversation is already preserved in the workflow coordinator's state.
+        Including it here would cause duplicate messages when restoring from checkpoint.
+        """
+        return {
+            "awaiting_agent_id": self.awaiting_agent_id,
+            "prompt": self.prompt,
+            "source_executor_id": self.source_executor_id,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "HandoffUserInputRequest":
+        """Deserialize from dict, initializing conversation as empty.
+
+        The conversation will be reconstructed from the coordinator's state on restore.
+        """
+        return cls(
+            conversation=[],
+            awaiting_agent_id=data["awaiting_agent_id"],
+            prompt=data["prompt"],
+            source_executor_id=data["source_executor_id"],
+        )
+
 
 @dataclass
 class _ConversationWithUserInput:
-    """Internal message carrying full conversation + new user messages from gateway to coordinator."""
+    """Internal message carrying full conversation + new user messages from gateway to coordinator.
+
+    Attributes:
+        full_conversation: The conversation messages to process.
+        is_post_restore: If True, indicates this message was created after a checkpoint restore.
+            The coordinator should append these messages to its existing conversation rather
+            than replacing it. This prevents duplicate messages (see issue #2667).
+    """
 
     full_conversation: list[ChatMessage] = field(default_factory=lambda: [])  # type: ignore[misc]
+    is_post_restore: bool = False
 
 
 @dataclass
@@ -439,9 +477,25 @@ async def handle_user_input(
         message: _ConversationWithUserInput,
         ctx: WorkflowContext[AgentExecutorRequest, list[ChatMessage]],
     ) -> None:
-        """Receive full conversation with new user input from gateway, update history, trim for agent."""
-        # Update authoritative conversation
-        self._conversation = list(message.full_conversation)
+        """Receive user input from gateway, update history, and route to agent.
+
+        The message.full_conversation may contain:
+        - Full conversation history + new user messages (normal flow)
+        - Only new user messages (post-checkpoint-restore flow, see issue #2667)
+
+        The gateway sets message.is_post_restore=True when resuming after a checkpoint
+        restore. In that case, we append the new messages to the existing conversation
+        rather than replacing it.
+        """
+        incoming = message.full_conversation
+
+        if message.is_post_restore and self._conversation:
+            # Post-restore: append new user messages to existing conversation
+            # The coordinator already has its conversation restored from checkpoint
+            self._conversation.extend(incoming)
+        else:
+            # Normal flow: replace with full conversation
+            self._conversation = list(incoming) if incoming else self._conversation
 
         # Reset autonomous turn counter on new user input
         self._autonomous_turns = 0
@@ -626,15 +680,24 @@ async def resume_from_user(
         response: object,
         ctx: WorkflowContext[_ConversationWithUserInput],
     ) -> None:
-        """Convert user input responses back into chat messages and resume the workflow."""
-        # Reconstruct full conversation with new user input
-        conversation = list(original_request.conversation)
+        """Convert user input responses back into chat messages and resume the workflow.
+
+        After checkpoint restore, original_request.conversation will be empty (not serialized
+        to prevent duplication - see issue #2667). In this case, we send only the new user
+        messages and let the coordinator append them to its already-restored conversation.
+        """
         user_messages = _as_user_messages(response)
-        conversation.extend(user_messages)
 
-        # Send full conversation back to coordinator (not trimmed)
-        # Coordinator will update its authoritative history and trim for agent
-        message = _ConversationWithUserInput(full_conversation=conversation)
+        if original_request.conversation:
+            # Normal flow: have conversation history from the original request
+            conversation = list(original_request.conversation)
+            conversation.extend(user_messages)
+            message = _ConversationWithUserInput(full_conversation=conversation, is_post_restore=False)
+        else:
+            # Post-restore flow: conversation was not serialized, send only new user messages
+            # The coordinator will append these to its already-restored conversation
+            message = _ConversationWithUserInput(full_conversation=user_messages, is_post_restore=True)
+
         await ctx.send_message(message, target_id="handoff-coordinator")
diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py
index 0ceccfaf15..077cb7321e 100644
--- a/python/packages/core/tests/workflow/test_handoff.py
+++ b/python/packages/core/tests/workflow/test_handoff.py
@@ -25,7 +25,12 @@
 from agent_framework._mcp import MCPTool
 from agent_framework._workflows import AgentRunEvent
 from agent_framework._workflows import _handoff as handoff_module  # type: ignore
-from agent_framework._workflows._handoff import _clone_chat_agent  # type: ignore[reportPrivateUsage]
+from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
+from agent_framework._workflows._handoff import (
+    _clone_chat_agent,  # type: ignore[reportPrivateUsage]
+    _ConversationWithUserInput,
+    _UserInputGateway,
+)
 from agent_framework._workflows._workflow_builder import WorkflowBuilder
@@ -775,3 +780,402 @@ async def test_return_to_previous_state_serialization():
 
     # Verify current_agent_id was restored
     assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint"  # type: ignore[reportPrivateUsage]
+
+
+async def test_handoff_user_input_request_checkpoint_excludes_conversation():
+    """Test that HandoffUserInputRequest serialization excludes conversation to prevent duplication.
+
+    Issue #2667: When checkpointing a workflow with a pending HandoffUserInputRequest,
+    the conversation field gets serialized twice: once in the RequestInfoEvent's data
+    and once in the coordinator's conversation state. On restore, this causes duplicate
+    messages.
+
+    The fix is to exclude the conversation field during checkpoint serialization since
+    the conversation is already preserved in the coordinator's state.
+    """
+    # Create a conversation history
+    conversation = [
+        ChatMessage(role=Role.USER, text="Hello"),
+        ChatMessage(role=Role.ASSISTANT, text="Hi there!"),
+        ChatMessage(role=Role.USER, text="Help me"),
+    ]
+
+    # Create a HandoffUserInputRequest with the conversation
+    request = HandoffUserInputRequest(
+        conversation=conversation,
+        awaiting_agent_id="specialist_agent",
+        prompt="Please provide your input",
+        source_executor_id="gateway",
+    )
+
+    # Encode the request (simulating checkpoint save)
+    encoded = encode_checkpoint_value(request)
+
+    # Verify conversation is NOT in the encoded output
+    # The fix should exclude conversation from serialization
+    assert isinstance(encoded, dict)
+
+    # If using MODEL_MARKER strategy (to_dict/from_dict)
+    if "__af_model__" in encoded or "__af_dataclass__" in encoded:
+        value = encoded.get("value", {})
+        assert "conversation" not in value, "conversation should be excluded from checkpoint serialization"
+
+    # Decode the request (simulating checkpoint restore)
+    decoded = decode_checkpoint_value(encoded)
+
+    # Verify the decoded request is a HandoffUserInputRequest
+    assert isinstance(decoded, HandoffUserInputRequest)
+
+    # Verify other fields are preserved
+    assert decoded.awaiting_agent_id == "specialist_agent"
+    assert decoded.prompt == "Please provide your input"
+    assert decoded.source_executor_id == "gateway"
+
+    # Conversation should be an empty list after deserialization
+    # (will be reconstructed from coordinator state on restore)
+    assert decoded.conversation == []
+
+
+async def test_handoff_user_input_request_roundtrip_preserves_metadata():
+    """Test that non-conversation fields survive checkpoint roundtrip."""
+    request = HandoffUserInputRequest(
+        conversation=[ChatMessage(role=Role.USER, text="test")],
+        awaiting_agent_id="test_agent",
+        prompt="Enter your response",
+        source_executor_id="test_gateway",
+    )
+
+    # Roundtrip through checkpoint encoding
+    encoded = encode_checkpoint_value(request)
+    decoded = decode_checkpoint_value(encoded)
+
+    assert isinstance(decoded, HandoffUserInputRequest)
+    assert decoded.awaiting_agent_id == request.awaiting_agent_id
+    assert decoded.prompt == request.prompt
+    assert decoded.source_executor_id == request.source_executor_id
+
+
+async def test_request_info_event_with_handoff_user_input_request():
+    """Test RequestInfoEvent serialization with HandoffUserInputRequest data."""
+    conversation = [
+        ChatMessage(role=Role.USER, text="Hello"),
+        ChatMessage(role=Role.ASSISTANT, text="How can I help?"),
+    ]
+
+    request = HandoffUserInputRequest(
+        conversation=conversation,
+        awaiting_agent_id="specialist",
+        prompt="Provide input",
+        source_executor_id="gateway",
+    )
+
+    # Create a RequestInfoEvent wrapping the request
+    event = RequestInfoEvent(
+        request_id="test-request-123",
+        source_executor_id="gateway",
+        request_data=request,
+        response_type=object,
+    )
+
+    # Serialize the event
+    event_dict = event.to_dict()
+
+    # Verify the data field doesn't contain conversation
+    data_encoded = event_dict["data"]
+    if isinstance(data_encoded, dict) and ("__af_model__" in data_encoded or "__af_dataclass__" in data_encoded):
+        value = data_encoded.get("value", {})
+        assert "conversation" not in value
+
+    # Deserialize and verify
+    restored_event = RequestInfoEvent.from_dict(event_dict)
+    assert isinstance(restored_event.data, HandoffUserInputRequest)
+    assert restored_event.data.awaiting_agent_id == "specialist"
+    assert restored_event.data.conversation == []
+
+
+async def test_handoff_user_input_request_to_dict_excludes_conversation():
+    """Test that to_dict() method excludes conversation field."""
+    conversation = [
+        ChatMessage(role=Role.USER, text="Hello"),
+        ChatMessage(role=Role.ASSISTANT, text="Hi!"),
+    ]
+
+    request = HandoffUserInputRequest(
+        conversation=conversation,
+        awaiting_agent_id="agent1",
+        prompt="Enter input",
+        source_executor_id="gateway",
+    )
+
+    # Call to_dict directly
+    data = request.to_dict()
+
+    # Verify conversation is excluded
+    assert "conversation" not in data
+    assert data["awaiting_agent_id"] == "agent1"
+    assert data["prompt"] == "Enter input"
+    assert data["source_executor_id"] == "gateway"
+
+
+async def test_handoff_user_input_request_from_dict_creates_empty_conversation():
+    """Test that from_dict() creates an instance with empty conversation."""
+    data = {
+        "awaiting_agent_id": "agent1",
+        "prompt": "Enter input",
+        "source_executor_id": "gateway",
+    }
+
+    request = HandoffUserInputRequest.from_dict(data)
+
+    assert request.conversation == []
+    assert request.awaiting_agent_id == "agent1"
+    assert request.prompt == "Enter input"
+    assert request.source_executor_id == "gateway"
+
+
+async def test_user_input_gateway_resume_handles_empty_conversation():
+    """Test that _UserInputGateway.resume_from_user handles post-restore scenario.
+
+    After checkpoint restore, the HandoffUserInputRequest will have an empty
+    conversation. The gateway should handle this by sending only the new user
+    messages to the coordinator.
+    """
+    from unittest.mock import AsyncMock
+
+    # Create a gateway
+    gateway = _UserInputGateway(
+        starting_agent_id="coordinator",
+        prompt="Enter input",
+        id="test-gateway",
+    )
+
+    # Simulate post-restore: request with empty conversation
+    restored_request = HandoffUserInputRequest(
+        conversation=[],  # Empty after restore
+        awaiting_agent_id="specialist",
+        prompt="Enter input",
+        source_executor_id="test-gateway",
+    )
+
+    # Create mock context
+    mock_ctx = MagicMock()
+    mock_ctx.send_message = AsyncMock()
+
+    # Call resume_from_user with a user response
+    await gateway.resume_from_user(restored_request, "New user message", mock_ctx)
+
+    # Verify send_message was called
+    mock_ctx.send_message.assert_called_once()
+
+    # Get the message that was sent
+    call_args = mock_ctx.send_message.call_args
+    sent_message = call_args[0][0]
+
+    # Verify it's a _ConversationWithUserInput
+    assert isinstance(sent_message, _ConversationWithUserInput)
+
+    # Verify it contains only the new user message (not any history)
+    assert len(sent_message.full_conversation) == 1
+    assert sent_message.full_conversation[0].role == Role.USER
+    assert sent_message.full_conversation[0].text == "New user message"
+
+
+async def test_user_input_gateway_resume_with_full_conversation():
+    """Test that _UserInputGateway.resume_from_user handles normal flow correctly.
+
+    In normal flow (no checkpoint restore), the HandoffUserInputRequest has
+    the full conversation. The gateway should send the full conversation
+    plus the new user messages.
+    """
+    from unittest.mock import AsyncMock
+
+    # Create a gateway
+    gateway = _UserInputGateway(
+        starting_agent_id="coordinator",
+        prompt="Enter input",
+        id="test-gateway",
+    )
+
+    # Normal flow: request with full conversation
+    normal_request = HandoffUserInputRequest(
+        conversation=[
+            ChatMessage(role=Role.USER, text="Hello"),
+            ChatMessage(role=Role.ASSISTANT, text="Hi!"),
+        ],
+        awaiting_agent_id="specialist",
+        prompt="Enter input",
+        source_executor_id="test-gateway",
+    )
+
+    # Create mock context
+    mock_ctx = MagicMock()
+    mock_ctx.send_message = AsyncMock()
+
+    # Call resume_from_user with a user response
+    await gateway.resume_from_user(normal_request, "Follow up message", mock_ctx)
+
+    # Verify send_message was called
+    mock_ctx.send_message.assert_called_once()
+
+    # Get the message that was sent
+    call_args = mock_ctx.send_message.call_args
+    sent_message = call_args[0][0]
+
+    # Verify it's a _ConversationWithUserInput
+    assert isinstance(sent_message, _ConversationWithUserInput)
+
+    # Verify it contains the full conversation plus new user message
+    assert len(sent_message.full_conversation) == 3
+    assert sent_message.full_conversation[0].text == "Hello"
+    assert sent_message.full_conversation[1].text == "Hi!"
+    assert sent_message.full_conversation[2].text == "Follow up message"
+
+
+async def test_coordinator_handle_user_input_post_restore():
+    """Test that _HandoffCoordinator.handle_user_input handles post-restore correctly.
+
+    After checkpoint restore, the coordinator has its conversation restored,
+    and the gateway sends only the new user messages. The coordinator should
+    append these to its existing conversation rather than replacing.
+    """
+    from unittest.mock import AsyncMock
+
+    from agent_framework._workflows._handoff import _HandoffCoordinator
+
+    # Create a coordinator with pre-existing conversation (simulating restored state)
+    coordinator = _HandoffCoordinator(
+        starting_agent_id="triage",
+        specialist_ids={"specialist_a": "specialist_a"},
+        input_gateway_id="gateway",
+        termination_condition=lambda conv: False,
+        id="test-coordinator",
+    )
+
+    # Simulate restored conversation
+    coordinator._conversation = [
+        ChatMessage(role=Role.USER, text="Hello"),
+        ChatMessage(role=Role.ASSISTANT, text="Hi there!"),
+        ChatMessage(role=Role.USER, text="Help me"),
+        ChatMessage(role=Role.ASSISTANT, text="Sure, what do you need?"),
+    ]
+
+    # Create mock context
+    mock_ctx = MagicMock()
+    mock_ctx.send_message = AsyncMock()
+
+    # Simulate post-restore: only new user message with explicit flag
+    incoming = _ConversationWithUserInput(
+        full_conversation=[ChatMessage(role=Role.USER, text="I need shipping help")],
+        is_post_restore=True,
+    )
+
+    # Handle the user input
+    await coordinator.handle_user_input(incoming, mock_ctx)
+
+    # Verify conversation was appended, not replaced
+    assert len(coordinator._conversation) == 5
+    assert coordinator._conversation[0].text == "Hello"
+    assert coordinator._conversation[1].text == "Hi there!"
+    assert coordinator._conversation[2].text == "Help me"
+    assert coordinator._conversation[3].text == "Sure, what do you need?"
+    assert coordinator._conversation[4].text == "I need shipping help"
+
+
+async def test_coordinator_handle_user_input_normal_flow():
+    """Test that _HandoffCoordinator.handle_user_input handles normal flow correctly.
+
+    In normal flow (no restore), the gateway sends the full conversation.
+    The coordinator should replace its conversation with the incoming one.
+    """
+    from unittest.mock import AsyncMock
+
+    from agent_framework._workflows._handoff import _HandoffCoordinator
+
+    # Create a coordinator
+    coordinator = _HandoffCoordinator(
+        starting_agent_id="triage",
+        specialist_ids={"specialist_a": "specialist_a"},
+        input_gateway_id="gateway",
+        termination_condition=lambda conv: False,
+        id="test-coordinator",
+    )
+
+    # Set some initial conversation
+    coordinator._conversation = [
+        ChatMessage(role=Role.USER, text="Old message"),
+    ]
+
+    # Create mock context
+    mock_ctx = MagicMock()
+    mock_ctx.send_message = AsyncMock()
+
+    # Normal flow: full conversation including new user message (is_post_restore=False by default)
+    incoming = _ConversationWithUserInput(
+        full_conversation=[
+            ChatMessage(role=Role.USER, text="Hello"),
+            ChatMessage(role=Role.ASSISTANT, text="Hi!"),
+            ChatMessage(role=Role.USER, text="New message"),
+        ],
+        is_post_restore=False,
+    )
+
+    # Handle the user input
+    await coordinator.handle_user_input(incoming, mock_ctx)
+
+    # Verify conversation was replaced (normal flow with full history)
+    assert len(coordinator._conversation) == 3
+    assert coordinator._conversation[0].text == "Hello"
+    assert coordinator._conversation[1].text == "Hi!"
+    assert coordinator._conversation[2].text == "New message"
+
+
+async def test_coordinator_handle_user_input_multiple_consecutive_user_messages():
+    """Test that multiple consecutive USER messages in normal flow are handled correctly.
+
+    This is a regression test for the edge case where a user submits multiple consecutive
+    USER messages. The explicit is_post_restore flag ensures this doesn't get incorrectly
+    detected as a post-restore scenario.
+    """
+    from unittest.mock import AsyncMock
+
+    from agent_framework._workflows._handoff import _HandoffCoordinator
+
+    # Create a coordinator with existing conversation
+    coordinator = _HandoffCoordinator(
+        starting_agent_id="triage",
+        specialist_ids={"specialist_a": "specialist_a"},
+        input_gateway_id="gateway",
+        termination_condition=lambda conv: False,
+        id="test-coordinator",
+    )
+
+    # Set existing conversation with 4 messages
+    coordinator._conversation = [
+        ChatMessage(role=Role.USER, text="Original message 1"),
+        ChatMessage(role=Role.ASSISTANT, text="Response 1"),
+        ChatMessage(role=Role.USER, text="Original message 2"),
+        ChatMessage(role=Role.ASSISTANT, text="Response 2"),
+    ]
+
+    # Create mock context
+    mock_ctx = MagicMock()
+    mock_ctx.send_message = AsyncMock()
+
+    # Normal flow: User sends multiple consecutive USER messages
+    # This should REPLACE the conversation, not append to it
+    incoming = _ConversationWithUserInput(
+        full_conversation=[
+            ChatMessage(role=Role.USER, text="New user message 1"),
+            ChatMessage(role=Role.USER, text="New user message 2"),
+        ],
+        is_post_restore=False,  # Explicit flag - this is normal flow
+    )
+
+    # Handle the user input
+    await coordinator.handle_user_input(incoming, mock_ctx)
+
+    # Verify conversation was REPLACED (not appended)
+    # Without the explicit flag, the old heuristic might incorrectly append
+    assert len(coordinator._conversation) == 2
+    assert coordinator._conversation[0].text == "New user message 1"
+    assert coordinator._conversation[1].text == "New user message 2"
diff --git a/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py b/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py
index 8981c1fddd..d9a996e807 100644
--- a/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py
+++ b/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py
@@ -122,11 +122,17 @@ def _print_handoff_request(request: HandoffUserInputRequest, request_id: str) -> None:
     print(f"Awaiting agent: {request.awaiting_agent_id}")
     print(f"Prompt: {request.prompt}")
 
-    print("\nConversation so far:")
-    for msg in request.conversation[-3:]:
-        author = msg.author_name or msg.role.value
-        snippet = msg.text[:120] + "..." if len(msg.text) > 120 else msg.text
-        print(f" {author}: {snippet}")
+    # Note: After checkpoint restore, conversation may be empty because it's not serialized
+    # to prevent duplication (the conversation is preserved in the coordinator's state).
+    # See issue #2667.
+    if request.conversation:
+        print("\nConversation so far:")
+        for msg in request.conversation[-3:]:
+            author = msg.author_name or msg.role.value
+            snippet = msg.text[:120] + "..." if len(msg.text) > 120 else msg.text
+            print(f" {author}: {snippet}")
+    else:
+        print("\n(Conversation restored from checkpoint - context preserved in workflow state)")
 
     print(f"{'=' * 60}\n")
@@ -273,11 +279,7 @@ async def resume_with_responses(
         elif isinstance(event, WorkflowOutputEvent):
             print("\n[Workflow Output Event - Conversation Update]")
-            if (
-                event.data
-                and isinstance(event.data, list)
-                and all(isinstance(msg, ChatMessage) for msg in event.data)
-            ):
+            if event.data and isinstance(event.data, list) and all(isinstance(msg, ChatMessage) for msg in event.data):
                 # Now safe to cast event.data to list[ChatMessage]
                 conversation = cast(list[ChatMessage], event.data)
                 for msg in conversation[-3:]:  # Show last 3 messages