Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/packages/core/agent_framework/_workflows/_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,15 @@ async def handle_agent_response(

if await self._check_termination():
logger.info("Handoff workflow termination condition met. Ending conversation.")
await ctx.yield_output(list(conversation))
# Clean the output conversation for display
cleaned_output = clean_conversation_for_handoff(conversation)
await ctx.yield_output(cleaned_output)
return

await ctx.send_message(list(conversation), target_id=self._input_gateway_id)
# Clean conversation before sending to gateway for user input request
# This removes tool messages that shouldn't be shown to users
cleaned_for_display = clean_conversation_for_handoff(conversation)
await ctx.send_message(cleaned_for_display, target_id=self._input_gateway_id)

@handler
async def handle_user_input(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def clean_conversation_for_handoff(conversation: list[ChatMessage]) -> list[Chat

# Has tool content - only keep if it also has text
if msg.text and msg.text.strip():
# Create fresh text-only message
# Create fresh text-only message while preserving additional_properties
msg_copy = ChatMessage(
role=msg.role,
text=msg.text,
author_name=msg.author_name,
additional_properties=dict(msg.additional_properties) if msg.additional_properties else None,
)
cleaned.append(msg_copy)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,11 @@ def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
self._runtime_checkpoint_storage = storage

def clear_runtime_checkpoint_storage(self) -> None:
    """Clear runtime checkpoint storage override.

    This is called automatically by workflow execution methods after a run completes,
    ensuring runtime storage doesn't leak across runs.
    """
    # Defect fixed: a stale duplicate of the old one-line docstring sat between
    # the signature and the real docstring (a merge/diff artifact), leaving a
    # dead string-literal statement in the body. Only the current docstring is kept.
    self._runtime_checkpoint_storage = None

def has_checkpointing(self) -> bool:
Expand Down Expand Up @@ -396,6 +400,7 @@ def reset_for_new_run(self) -> None:
"""Reset the context for a new workflow run.

This clears messages, events, and resets streaming flag.
Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level.
"""
self._messages.clear()
# Clear any pending events (best-effort) by recreating the queue
Expand Down
118 changes: 47 additions & 71 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class Workflow(DictConvertible):
2. Executor implements `response_handler()` to process the response
3. Requests are emitted as RequestInfoEvent instances in the event stream
4. Workflow enters IDLE_WITH_PENDING_REQUESTS state
5. Caller handles requests and provides responses via the executor's handle_response method
5. Caller handles requests and provides responses via the `send_responses` or `send_responses_streaming` methods
6. Responses are routed to the requesting executors and response handlers are invoked

## Checkpointing
Expand Down Expand Up @@ -384,6 +384,46 @@ async def _run_workflow_with_tracing(
capture_exception(span, exception=exc)
raise

async def _execute_with_message_or_checkpoint(
self,
message: Any | None,
checkpoint_id: str | None,
checkpoint_storage: CheckpointStorage | None,
) -> None:
"""Internal handler for executing workflow with either initial message or checkpoint restoration.

Args:
message: Initial message for the start executor (for new runs).
checkpoint_id: ID of checkpoint to restore from (for resuming runs).
checkpoint_storage: Runtime checkpoint storage.
"""
# Handle checkpoint restoration
if checkpoint_id is not None:
has_checkpointing = self._runner.context.has_checkpointing()

if not has_checkpointing and checkpoint_storage is None:
raise ValueError(
"Cannot restore from checkpoint: either provide checkpoint_storage parameter "
"or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)."
)

restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)

if not restored:
raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}")

# Handle initial message
elif message is not None:
executor = self.get_start_executor()
await executor.execute(
message,
[self.__class__.__name__],
self._shared_state,
self._runner.context,
trace_contexts=None,
source_span_ids=None,
)

async def run_stream(
self,
message: Any | None = None,
Expand Down Expand Up @@ -463,50 +503,18 @@ async def run_stream(
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)

try:

async def execution_handler() -> None:
# Handle checkpoint restoration
if checkpoint_id is not None:
has_checkpointing = self._runner.context.has_checkpointing()

if not has_checkpointing and checkpoint_storage is None:
raise ValueError(
"Cannot restore from checkpoint: either provide checkpoint_storage parameter "
"or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)."
)

restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)

if not restored:
raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}")

# Process pending messages from checkpoint
if await self._runner.context.has_messages():
await self._runner._run_iteration() # type: ignore

# Handle initial message
elif message is not None:
executor = self.get_start_executor()
await executor.execute(
message,
[self.__class__.__name__],
self._shared_state,
self._runner.context,
trace_contexts=None,
source_span_ids=None,
)

# Reset context only for new runs (not checkpoint restoration)
reset_context = message is not None and checkpoint_id is None

async for event in self._run_workflow_with_tracing(
initial_executor_fn=execution_handler,
initial_executor_fn=functools.partial(
self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage
),
reset_context=reset_context,
streaming=True,
):
yield event
finally:
# Clear runtime checkpoint storage after run completes
if checkpoint_storage is not None:
self._runner.context.clear_runtime_checkpoint_storage()
self._reset_running_flag()
Expand Down Expand Up @@ -607,51 +615,19 @@ async def run(
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)

try:

async def execution_handler() -> None:
# Handle checkpoint restoration
if checkpoint_id is not None:
has_checkpointing = self._runner.context.has_checkpointing()

if not has_checkpointing and checkpoint_storage is None:
raise ValueError(
"Cannot restore from checkpoint: either provide checkpoint_storage parameter "
"or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)."
)

restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)

if not restored:
raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}")

# Process pending messages from checkpoint
if await self._runner.context.has_messages():
await self._runner._run_iteration() # type: ignore

# Handle initial message
elif message is not None:
executor = self.get_start_executor()
await executor.execute(
message,
[self.__class__.__name__],
self._shared_state,
self._runner.context,
trace_contexts=None,
source_span_ids=None,
)

# Reset context only for new runs (not checkpoint restoration)
reset_context = message is not None and checkpoint_id is None

raw_events = [
event
async for event in self._run_workflow_with_tracing(
initial_executor_fn=execution_handler,
initial_executor_fn=functools.partial(
self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage
),
reset_context=reset_context,
)
]
finally:
# Clear runtime checkpoint storage after run completes
if checkpoint_storage is not None:
self._runner.context.clear_runtime_checkpoint_storage()
self._reset_running_flag()
Expand Down
42 changes: 7 additions & 35 deletions python/packages/core/tests/workflow/test_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,33 +155,6 @@ async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
return [event async for event in stream]


@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_handoff_routes_to_specialist_and_requests_user_input():
    """Coordinator hands off to the specialist, then the workflow pauses for user input."""
    triage = _RecordingAgent(name="triage", handoff_to="specialist")
    specialist = _RecordingAgent(name="specialist")

    workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator("triage").build()

    stream_events = await _drain(workflow.run_stream("Need help with a refund"))

    # Both agents must have been invoked, in order.
    assert triage.calls, "Starting agent should receive initial conversation"
    assert specialist.calls, "Specialist should be invoked after handoff"
    assert len(specialist.calls[0]) == 2  # user + triage reply

    info_requests = [ev for ev in stream_events if isinstance(ev, RequestInfoEvent)]
    assert info_requests, "Workflow should request additional user input"

    payload = info_requests[-1].data
    assert isinstance(payload, HandoffUserInputRequest)
    # Expected transcript: user, triage tool call, tool ack, specialist
    assert len(payload.conversation) == 4
    assert payload.conversation[2].role == Role.TOOL
    assert payload.conversation[3].role == Role.ASSISTANT
    assert "specialist reply" in payload.conversation[3].text

    # Answering the request should produce another round of user-input requests.
    second_round = await _drain(workflow.run_stream(responses={info_requests[-1].request_id: "Thanks"}))
    assert any(isinstance(ev, RequestInfoEvent) for ev in second_round)


@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_specialist_to_specialist_handoff():
"""Test that specialists can hand off to other specialists via .add_handoff() configuration."""
triage = _RecordingAgent(name="triage", handoff_to="specialist")
Expand All @@ -206,15 +179,14 @@ async def test_specialist_to_specialist_handoff():
assert len(specialist.calls) > 0

# Second user message - specialist hands off to escalation
events = await _drain(workflow.run_stream(responses={requests[-1].request_id: "This is complex"}))
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "This is complex"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs

# Escalation should have been called
assert len(escalation.calls) > 0


@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_handoff_preserves_complex_additional_properties(complex_metadata: _ComplexMetadata):
triage = _RecordingAgent(name="triage", handoff_to="specialist", extra_properties={"complex": complex_metadata})
specialist = _RecordingAgent(name="specialist")
Expand Down Expand Up @@ -259,7 +231,9 @@ async def test_handoff_preserves_complex_additional_properties(complex_metadata:
assert restored_meta.payload["code"] == "X1"

# Respond and ensure metadata survives subsequent cycles
follow_up_events = await _drain(workflow.run_stream(responses={requests[-1].request_id: "Here are more details"}))
follow_up_events = await _drain(
workflow.send_responses_streaming({requests[-1].request_id: "Here are more details"})
)
follow_up_requests = [ev for ev in follow_up_events if isinstance(ev, RequestInfoEvent)]
outputs = [ev for ev in follow_up_events if isinstance(ev, WorkflowOutputEvent)]

Expand Down Expand Up @@ -310,7 +284,6 @@ def test_build_fails_without_participants():
HandoffBuilder().build()


@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_multiple_runs_dont_leak_conversation():
"""Verify that running the same workflow multiple times doesn't leak conversation history."""
triage = _RecordingAgent(name="triage", handoff_to="specialist")
Expand All @@ -327,7 +300,7 @@ async def test_multiple_runs_dont_leak_conversation():
events = await _drain(workflow.run_stream("First run message"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
events = await _drain(workflow.run_stream(responses={requests[-1].request_id: "Second message"}))
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs, "First run should emit output"

Expand All @@ -345,7 +318,7 @@ async def test_multiple_runs_dont_leak_conversation():
events = await _drain(workflow.run_stream("Second run different message"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
events = await _drain(workflow.run_stream(responses={requests[-1].request_id: "Another message"}))
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Another message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs, "Second run should emit output"

Expand All @@ -362,7 +335,6 @@ async def test_multiple_runs_dont_leak_conversation():
)


@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_handoff_async_termination_condition() -> None:
"""Test that async termination conditions work correctly."""
termination_call_count = 0
Expand All @@ -386,7 +358,7 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests

events = await _drain(workflow.run_stream(responses={requests[-1].request_id: "Second user message"}))
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second user message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert len(outputs) == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,21 @@ async def main() -> None:

pending_responses: dict[str, str] | None = None
completed = False
initial_run = True

while not completed:
last_executor: str | None = None
stream = (
workflow.run_stream(responses=pending_responses)
if pending_responses is not None
else workflow.run_stream(
if initial_run:
stream = workflow.run_stream(
"Create a short launch blurb for the LumenX desk lamp. Emphasize adjustability and warm lighting."
)
)
pending_responses = None
initial_run = False
elif pending_responses is not None:
stream = workflow.send_responses_streaming(pending_responses)
pending_responses = None
else:
break

requests: list[tuple[str, DraftFeedbackRequest]] = []

async for event in stream:
Expand Down