Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def _run_stream_impl(
"""
# Determine the event stream based on whether we have function responses
if bool(self.pending_requests):
# This is a continuation - use send_responses_streaming to send function responses back
# This is a continuation - use run_stream with responses to send function responses back
logger.info(f"Continuing workflow to address {len(self.pending_requests)} requests")

# Extract function responses from input messages, and ensure that
Expand All @@ -212,7 +212,7 @@ async def _run_stream_impl(
# NOTE: It is possible that some pending requests are not fulfilled,
# and we will let the workflow to handle this -- the agent does not
# have an opinion on this.
event_stream = self.workflow.send_responses_streaming(function_responses)
event_stream = self.workflow.run_stream(responses=function_responses)
else:
# Execute workflow with streaming (initial run or no function responses)
# Pass the new input messages directly to the workflow
Expand Down
38 changes: 0 additions & 38 deletions python/packages/core/agent_framework/_workflows/_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2458,17 +2458,6 @@ async def _validate_checkpoint_participants(
f"Missing names: {missing}; unexpected names: {unexpected}."
)

async def run_stream_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
responses: dict[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Resume orchestration from a checkpoint and stream resulting events."""
await self._validate_checkpoint_participants(checkpoint_id, checkpoint_storage)
async for event in self._workflow.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage, responses):
yield event

async def run_with_string(self, task_text: str) -> WorkflowRunResult:
"""Run the workflow with a task string and return all events.

Expand Down Expand Up @@ -2512,33 +2501,6 @@ async def run(self, message: Any | None = None) -> WorkflowRunResult:
events.append(event)
return WorkflowRunResult(events)

async def run_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
responses: dict[str, Any] | None = None,
) -> WorkflowRunResult:
"""Resume orchestration from a checkpoint and collect all resulting events."""
events: list[WorkflowEvent] = []
async for event in self.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage, responses):
events.append(event)
return WorkflowRunResult(events)

async def send_responses_streaming(self, responses: dict[str, Any]) -> AsyncIterable[WorkflowEvent]:
"""Forward responses to pending requests and stream resulting events.

This delegates to the underlying Workflow implementation.
"""
async for event in self._workflow.send_responses_streaming(responses):
yield event

async def send_responses(self, responses: dict[str, Any]) -> WorkflowRunResult:
"""Forward responses to pending requests and return all resulting events.

This delegates to the underlying Workflow implementation.
"""
return await self._workflow.send_responses(responses)

def __getattr__(self, name: str) -> Any:
"""Delegate unknown attributes to the underlying workflow."""
return getattr(self._workflow, name)
Expand Down
41 changes: 36 additions & 5 deletions python/packages/core/agent_framework/_workflows/_runner_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ def has_checkpointing(self) -> bool:
"""
...

def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
"""Set runtime checkpoint storage to override build-time configuration.

Args:
storage: The checkpoint storage to use for this run.
"""
...

def clear_runtime_checkpoint_storage(self) -> None:
"""Clear runtime checkpoint storage override."""
...

# Checkpointing APIs (optional, enabled by storage)
def set_workflow_id(self, workflow_id: str) -> None:
"""Set the workflow ID for the context."""
Expand Down Expand Up @@ -202,6 +214,7 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None):

# Checkpointing configuration/state
self._checkpoint_storage = checkpoint_storage
self._runtime_checkpoint_storage: CheckpointStorage | None = None
self._workflow_id: str | None = None

# Streaming flag - set by workflow's run_stream() vs run()
Expand Down Expand Up @@ -252,16 +265,33 @@ async def next_event(self) -> WorkflowEvent:

# region Checkpointing

def _get_effective_checkpoint_storage(self) -> CheckpointStorage | None:
"""Get the effective checkpoint storage (runtime override or build-time)."""
return self._runtime_checkpoint_storage or self._checkpoint_storage

def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None:
"""Set runtime checkpoint storage to override build-time configuration.

Args:
storage: The checkpoint storage to use for this run.
"""
self._runtime_checkpoint_storage = storage

def clear_runtime_checkpoint_storage(self) -> None:
"""Clear runtime checkpoint storage override."""
self._runtime_checkpoint_storage = None

def has_checkpointing(self) -> bool:
return self._checkpoint_storage is not None
return self._get_effective_checkpoint_storage() is not None

async def create_checkpoint(
self,
shared_state: SharedState,
iteration_count: int,
metadata: dict[str, Any] | None = None,
) -> str:
if not self._checkpoint_storage:
storage = self._get_effective_checkpoint_storage()
if not storage:
raise ValueError("Checkpoint storage not configured")

self._workflow_id = self._workflow_id or str(uuid.uuid4())
Expand All @@ -274,14 +304,15 @@ async def create_checkpoint(
iteration_count=state["iteration_count"],
metadata=metadata or {},
)
checkpoint_id = await self._checkpoint_storage.save_checkpoint(checkpoint)
checkpoint_id = await storage.save_checkpoint(checkpoint)
logger.info(f"Created checkpoint {checkpoint_id} for workflow {self._workflow_id}")
return checkpoint_id

async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
if not self._checkpoint_storage:
storage = self._get_effective_checkpoint_storage()
if not storage:
raise ValueError("Checkpoint storage not configured")
return await self._checkpoint_storage.load_checkpoint(checkpoint_id)
return await storage.load_checkpoint(checkpoint_id)

def reset_for_new_run(self) -> None:
"""Reset the context for a new workflow run.
Expand Down
Loading