9 changes: 6 additions & 3 deletions python/packages/core/agent_framework/_workflows/_agent.py
@@ -60,10 +60,13 @@ def from_dict(cls, payload: dict[str, Any]) -> "WorkflowAgent.RequestInfoFunctionArgs":

    @classmethod
    def from_json(cls, raw: str) -> "WorkflowAgent.RequestInfoFunctionArgs":
-       data = json.loads(raw)
-       if not isinstance(data, dict):
+       try:
+           parsed: Any = json.loads(raw)
+       except json.JSONDecodeError as exc:
+           raise ValueError(f"RequestInfoFunctionArgs JSON payload is malformed: {exc}") from exc
+       if not isinstance(parsed, dict):
            raise ValueError("RequestInfoFunctionArgs JSON payload must decode to a mapping")
-       return cls.from_dict(data)
+       return cls.from_dict(cast(dict[str, Any], parsed))

    def __init__(
        self,
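With this change, `from_json` callers see one consistent failure mode: malformed JSON and well-formed payloads that aren't objects both surface as `ValueError`, instead of `json.JSONDecodeError` leaking through. A minimal standalone sketch of the same pattern (the function name is illustrative, not the actual API):

```python
import json
from typing import Any, cast


def parse_args_payload(raw: str) -> dict[str, Any]:
    """Parse a JSON payload, normalizing every failure mode to ValueError."""
    try:
        parsed: Any = json.loads(raw)
    except json.JSONDecodeError as exc:
        # Malformed JSON (e.g. truncated tool-call arguments) becomes ValueError.
        raise ValueError(f"payload is malformed: {exc}") from exc
    if not isinstance(parsed, dict):
        # Valid JSON that is not an object (e.g. "[1, 2]") is rejected the same way.
        raise ValueError("payload must decode to a mapping")
    return cast(dict[str, Any], parsed)


for bad in ('{"incomplete": ', "[1, 2, 3]"):
    try:
        parse_args_payload(bad)
    except ValueError as exc:
        print(f"{bad!r} -> ValueError: {exc}")
```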
21 changes: 13 additions & 8 deletions python/packages/core/agent_framework/_workflows/_handoff.py
@@ -336,10 +336,15 @@ async def handle_agent_response(

        if await self._check_termination():
            logger.info("Handoff workflow termination condition met. Ending conversation.")
-           await ctx.yield_output(list(conversation))
+           # Clean the output conversation for display
+           cleaned_output = clean_conversation_for_handoff(conversation)
+           await ctx.yield_output(cleaned_output)
            return

-       await ctx.send_message(list(conversation), target_id=self._input_gateway_id)
+       # Clean conversation before sending to gateway for user input request
+       # This removes tool messages that shouldn't be shown to users
+       cleaned_for_display = clean_conversation_for_handoff(conversation)
+       await ctx.send_message(cleaned_for_display, target_id=self._input_gateway_id)

    @handler
    async def handle_user_input(
@@ -1274,12 +1279,12 @@ def build(self) -> Workflow:
                updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map)
                self._executors[source_exec_id] = updated_executor
                handoff_tool_targets.update(tool_targets)
-           else:
-               # Default behavior: only coordinator gets handoff tools to all specialists
-               if isinstance(starting_executor, AgentExecutor) and specialists:
-                   starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists)
-                   self._executors[self._starting_agent_id] = starting_executor
-                   handoff_tool_targets.update(tool_targets)  # Update references after potential agent modifications
+       else:
+           # Default behavior: only coordinator gets handoff tools to all specialists
+           if isinstance(starting_executor, AgentExecutor) and specialists:
+               starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists)
+               self._executors[self._starting_agent_id] = starting_executor
+               handoff_tool_targets.update(tool_targets)  # Update references after potential agent modifications

        starting_executor = self._executors[self._starting_agent_id]
        specialists = {
            exec_id: executor for exec_id, executor in self._executors.items() if exec_id != self._starting_agent_id
36 changes: 0 additions & 36 deletions python/packages/core/agent_framework/_workflows/_magentic.py
@@ -2442,16 +2442,6 @@ async def _validate_checkpoint_participants(
f"Missing names: {missing}; unexpected names: {unexpected}."
)

async def run_stream_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Resume orchestration from a checkpoint and stream resulting events."""
await self._validate_checkpoint_participants(checkpoint_id, checkpoint_storage)
async for event in self._workflow.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage):
yield event

async def run_with_string(self, task_text: str) -> WorkflowRunResult:
"""Run the workflow with a task string and return all events.

@@ -2495,32 +2485,6 @@ async def run(self, message: Any | None = None) -> WorkflowRunResult:
            events.append(event)
        return WorkflowRunResult(events)

-   async def run_from_checkpoint(
-       self,
-       checkpoint_id: str,
-       checkpoint_storage: CheckpointStorage | None = None,
-   ) -> WorkflowRunResult:
-       """Resume orchestration from a checkpoint and collect all resulting events."""
-       events: list[WorkflowEvent] = []
-       async for event in self.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage):
-           events.append(event)
-       return WorkflowRunResult(events)
-
-   async def send_responses_streaming(self, responses: dict[str, Any]) -> AsyncIterable[WorkflowEvent]:
-       """Forward responses to pending requests and stream resulting events.
-
-       This delegates to the underlying Workflow implementation.
-       """
-       async for event in self._workflow.send_responses_streaming(responses):
-           yield event
-
-   async def send_responses(self, responses: dict[str, Any]) -> WorkflowRunResult:
-       """Forward responses to pending requests and return all resulting events.
-
-       This delegates to the underlying Workflow implementation.
-       """
-       return await self._workflow.send_responses(responses)
-
    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the underlying workflow."""
        return getattr(self._workflow, name)
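Most of the removed methods were pure pass-throughs (their docstrings say as much), and the `__getattr__` fallback kept at the bottom of the class already forwards any undefined attribute to the wrapped workflow, so the explicit wrappers added nothing. A minimal sketch of that delegation pattern (class names are stand-ins, not the real API):

```python
import asyncio


class StubWorkflow:
    """Stand-in for the wrapped workflow object."""

    async def send_responses(self, responses: dict[str, str]) -> str:
        return f"handled {sorted(responses)}"


class Orchestrator:
    """Sketch of the delegation that makes explicit pass-through wrappers redundant."""

    def __init__(self, workflow: StubWorkflow) -> None:
        self._workflow = workflow

    def __getattr__(self, name: str):
        # Invoked only when normal attribute lookup fails, so any method not
        # defined on Orchestrator transparently resolves on the workflow.
        return getattr(self._workflow, name)


orch = Orchestrator(StubWorkflow())
# No send_responses defined on Orchestrator; __getattr__ forwards the call.
print(asyncio.run(orch.send_responses({"req-1": "answer"})))
```

Note that `run_stream_from_checkpoint` was not a pure pass-through: it first ran `_validate_checkpoint_participants`, so its removal also drops that pre-validation from this call path.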
@@ -63,11 +63,12 @@ def clean_conversation_for_handoff(conversation: list[ChatMessage]) -> list[ChatMessage]:

        # Has tool content - only keep if it also has text
        if msg.text and msg.text.strip():
-           # Create fresh text-only message
+           # Create fresh text-only message while preserving additional_properties
            msg_copy = ChatMessage(
                role=msg.role,
                text=msg.text,
                author_name=msg.author_name,
+               additional_properties=dict(msg.additional_properties) if msg.additional_properties else None,
            )
            cleaned.append(msg_copy)
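Before this fix the text-only copy silently dropped `additional_properties`, losing any metadata attached to the message. A self-contained sketch of the corrected copy, using a stand-in dataclass rather than the real `ChatMessage`:

```python
from dataclasses import dataclass


@dataclass
class Message:
    """Stand-in for ChatMessage; only the fields the copy touches are modeled."""

    role: str
    text: str
    author_name: str | None = None
    additional_properties: dict | None = None


def text_only_copy(msg: Message) -> Message:
    # dict(...) takes a shallow copy, so mutating the clone's metadata
    # cannot affect the original message.
    return Message(
        role=msg.role,
        text=msg.text,
        author_name=msg.author_name,
        additional_properties=dict(msg.additional_properties) if msg.additional_properties else None,
    )


original = Message("assistant", "done", additional_properties={"trace_id": "t-1"})
clone = text_only_copy(original)
clone.additional_properties["trace_id"] = "t-2"
print(original.additional_properties)  # {'trace_id': 't-1'}; the original is untouched
```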
46 changes: 41 additions & 5 deletions python/packages/core/agent_framework/_workflows/_runner_context.py
@@ -171,6 +171,18 @@ def has_checkpointing(self) -> bool:
"""
...

def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
"""Set runtime checkpoint storage to override build-time configuration.

Args:
storage: The checkpoint storage to use for this run.
"""
...

def clear_runtime_checkpoint_storage(self) -> None:
"""Clear runtime checkpoint storage override."""
...

# Checkpointing APIs (optional, enabled by storage)
def set_workflow_id(self, workflow_id: str) -> None:
"""Set the workflow ID for the context."""
@@ -279,6 +291,7 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None):

        # Checkpointing configuration/state
        self._checkpoint_storage = checkpoint_storage
+       self._runtime_checkpoint_storage: CheckpointStorage | None = None
        self._workflow_id: str | None = None

        # Streaming flag - set by workflow's run_stream() vs run()
@@ -329,16 +342,37 @@ async def next_event(self) -> WorkflowEvent:

    # region Checkpointing

+   def _get_effective_checkpoint_storage(self) -> CheckpointStorage | None:
+       """Get the effective checkpoint storage (runtime override or build-time)."""
+       return self._runtime_checkpoint_storage or self._checkpoint_storage
+
+   def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
+       """Set runtime checkpoint storage to override build-time configuration.
+
+       Args:
+           storage: The checkpoint storage to use for this run.
+       """
+       self._runtime_checkpoint_storage = storage
+
+   def clear_runtime_checkpoint_storage(self) -> None:
+       """Clear runtime checkpoint storage override.
+
+       This is called automatically by workflow execution methods after a run completes,
+       ensuring runtime storage doesn't leak across runs.
+       """
+       self._runtime_checkpoint_storage = None
+
    def has_checkpointing(self) -> bool:
-       return self._checkpoint_storage is not None
+       return self._get_effective_checkpoint_storage() is not None

    async def create_checkpoint(
        self,
        shared_state: SharedState,
        iteration_count: int,
        metadata: dict[str, Any] | None = None,
    ) -> str:
-       if not self._checkpoint_storage:
+       storage = self._get_effective_checkpoint_storage()
+       if not storage:
            raise ValueError("Checkpoint storage not configured")

        self._workflow_id = self._workflow_id or str(uuid.uuid4())
@@ -352,19 +386,21 @@ async def create_checkpoint(
            iteration_count=state["iteration_count"],
            metadata=metadata or {},
        )
-       checkpoint_id = await self._checkpoint_storage.save_checkpoint(checkpoint)
+       checkpoint_id = await storage.save_checkpoint(checkpoint)
        logger.info(f"Created checkpoint {checkpoint_id} for workflow {self._workflow_id}")
        return checkpoint_id

    async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
-       if not self._checkpoint_storage:
+       storage = self._get_effective_checkpoint_storage()
+       if not storage:
            raise ValueError("Checkpoint storage not configured")
-       return await self._checkpoint_storage.load_checkpoint(checkpoint_id)
+       return await storage.load_checkpoint(checkpoint_id)

    def reset_for_new_run(self) -> None:
        """Reset the context for a new workflow run.

        This clears messages, events, and resets streaming flag.
+       Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level.
        """
        self._messages.clear()
        # Clear any pending events (best-effort) by recreating the queue
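Together, the override setter and the effective-storage lookup let a caller swap checkpoint storage for a single run without rebuilding the context, falling back to the build-time configuration once the override is cleared. A condensed sketch of the lookup order (stand-in classes, not the real RunnerContext):

```python
class StubStorage:
    """Stand-in for CheckpointStorage; carries only a label for demonstration."""

    def __init__(self, label: str) -> None:
        self.label = label


class Context:
    """Stand-in showing the runtime-override lookup order."""

    def __init__(self, checkpoint_storage: StubStorage | None = None) -> None:
        self._checkpoint_storage = checkpoint_storage  # build-time configuration
        self._runtime_checkpoint_storage: StubStorage | None = None

    def _get_effective_checkpoint_storage(self) -> StubStorage | None:
        # Runtime override wins; otherwise fall back to build-time storage.
        return self._runtime_checkpoint_storage or self._checkpoint_storage

    def set_runtime_checkpoint_storage(self, storage: StubStorage) -> None:
        self._runtime_checkpoint_storage = storage

    def clear_runtime_checkpoint_storage(self) -> None:
        self._runtime_checkpoint_storage = None


ctx = Context(checkpoint_storage=StubStorage("build-time"))
ctx.set_runtime_checkpoint_storage(StubStorage("per-run"))
print(ctx._get_effective_checkpoint_storage().label)  # per-run
ctx.clear_runtime_checkpoint_storage()  # the workflow clears this after each run
print(ctx._get_effective_checkpoint_storage().label)  # build-time
```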
95 changes: 0 additions & 95 deletions python/packages/core/agent_framework/_workflows/_validation.py
@@ -12,10 +12,6 @@

logger = logging.getLogger(__name__)

-# Track cycle signatures we've already reported to avoid spamming logs when workflows
-# with intentional feedback loops are constructed multiple times in the same process.
-_LOGGED_CYCLE_SIGNATURES: set[tuple[str, ...]] = set()
-

# region Enums and Base Classes
class ValidationTypeEnum(Enum):
@@ -168,7 +164,6 @@ def validate_workflow(
        self._validate_graph_connectivity(start_executor_id)
        self._validate_self_loops()
        self._validate_dead_ends()
-       self._validate_cycles()

    def _validate_handler_output_annotations(self) -> None:
        """Validate that each handler's ctx parameter is annotated with WorkflowContext[T].
@@ -394,96 +389,6 @@ def _validate_dead_ends(self) -> None:
f"Verify these are intended as final nodes in the workflow."
)

def _validate_cycles(self) -> None:
"""Detect cycles in the workflow graph.

Cycles might be intentional for iterative processing but should be flagged
for review to ensure proper termination conditions exist. We surface each
distinct cycle group only once per process to avoid noisy, repeated warnings
when rebuilding the same workflow.
"""
# Build adjacency list (ensure every executor appears even if it has no outgoing edges)
graph: dict[str, list[str]] = defaultdict(list)
for edge in self._edges:
graph[edge.source_id].append(edge.target_id)
graph.setdefault(edge.target_id, [])
for executor_id in self._executors:
graph.setdefault(executor_id, [])

# Tarjan's algorithm to locate strongly-connected components that form cycles
index: dict[str, int] = {}
lowlink: dict[str, int] = {}
on_stack: set[str] = set()
stack: list[str] = []
current_index = 0
cycle_components: list[list[str]] = []

def strongconnect(node: str) -> None:
nonlocal current_index

index[node] = current_index
lowlink[node] = current_index
current_index += 1
stack.append(node)
on_stack.add(node)

for neighbor in graph[node]:
if neighbor not in index:
strongconnect(neighbor)
lowlink[node] = min(lowlink[node], lowlink[neighbor])
elif neighbor in on_stack:
lowlink[node] = min(lowlink[node], index[neighbor])

if lowlink[node] == index[node]:
component: list[str] = []
while True:
member = stack.pop()
on_stack.discard(member)
component.append(member)
if member == node:
break

# A strongly connected component represents a cycle if it has more than one
# node or if a single node references itself directly.
if len(component) > 1 or any(member in graph[member] for member in component):
cycle_components.append(component)

for executor_id in graph:
if executor_id not in index:
strongconnect(executor_id)

if not cycle_components:
return

unseen_components: list[list[str]] = []
for component in cycle_components:
signature = tuple(sorted(component))
if signature in _LOGGED_CYCLE_SIGNATURES:
continue
_LOGGED_CYCLE_SIGNATURES.add(signature)
unseen_components.append(component)

if not unseen_components:
# All cycles already reported in this process; keep noise low but retain traceability.
logger.debug(
"Cycle detected in workflow graph but previously reported. Components: %s",
[sorted(component) for component in cycle_components],
)
return

def _format_cycle(component: list[str]) -> str:
if not component:
return ""
ordered = list(component)
ordered.append(component[0])
return " -> ".join(ordered)

formatted_cycles = ", ".join(_format_cycle(component) for component in unseen_components)
logger.warning(
"Cycle detected in the workflow graph involving: %s. Ensure termination or iteration limits exist.",
formatted_cycles,
)

# endregion


Expand Down
Loading