Skip to content

Commit 79cdc71

Browse files
Merge pull request #1078 from julep-ai/f/workflow-refactoring
Task execution workflow refactoring
2 parents 73cb9ec + d4a3311 commit 79cdc71

19 files changed

+1091
-561
lines changed

agents-api/agents_api/clients/temporal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def run_task_execution_workflow(
120120

121121
return await client.start_workflow(
122122
TaskExecutionWorkflow.run,
123-
args=[execution_input, start, previous_inputs],
123+
args=[execution_input, start, previous_inputs[-1]],
124124
task_queue=temporal_task_queue,
125125
id=str(job_id),
126126
run_timeout=timedelta(days=31),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from typing import Any
2+
from uuid import UUID
3+
4+
from beartype import beartype
5+
from pydantic import BaseModel
6+
7+
from ...autogen.openapi_model import (
8+
Agent,
9+
CreateTaskRequest,
10+
CreateToolRequest,
11+
Execution,
12+
PartialTaskSpecDef,
13+
PatchTaskRequest,
14+
Session,
15+
Task,
16+
TaskSpec,
17+
TaskSpecDef,
18+
TaskToolDef,
19+
Tool,
20+
UpdateTaskRequest,
21+
User,
22+
Workflow,
23+
)
24+
25+
26+
class ExecutionInput(BaseModel):
27+
developer_id: UUID
28+
execution: Execution | None = None
29+
task: TaskSpecDef | None = None
30+
agent: Agent
31+
agent_tools: list[Tool | CreateToolRequest]
32+
arguments: dict[str, Any]
33+
34+
# Not used at the moment
35+
user: User | None = None
36+
session: Session | None = None
37+
38+
39+
@beartype
40+
def task_to_spec(
41+
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
42+
) -> TaskSpecDef | PartialTaskSpecDef:
43+
task_data = task.model_dump(
44+
**model_opts, exclude={"version", "developer_id", "task_id", "id", "agent_id"}
45+
)
46+
47+
if "tools" in task_data:
48+
del task_data["tools"]
49+
50+
tools = []
51+
for tool in task.tools:
52+
tool_spec = getattr(tool, tool.type)
53+
54+
tool_obj = dict(
55+
type=tool.type,
56+
spec=tool_spec.model_dump(),
57+
**tool.model_dump(exclude={"type"}),
58+
)
59+
tools.append(TaskToolDef(**tool_obj))
60+
61+
workflows = [Workflow(name="main", steps=task_data.pop("main"))]
62+
63+
for key, steps in list(task_data.items()):
64+
if key not in TaskSpec.model_fields:
65+
workflows.append(Workflow(name=key, steps=steps))
66+
del task_data[key]
67+
68+
cls = PartialTaskSpecDef if isinstance(task, PatchTaskRequest) else TaskSpecDef
69+
70+
return cls(
71+
workflows=workflows,
72+
tools=tools,
73+
**task_data,
74+
)
75+
76+
77+
def spec_to_task_data(spec: dict) -> dict:
78+
task_id = spec.pop("task_id", None)
79+
80+
workflows = spec.pop("workflows")
81+
workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows}
82+
83+
tools = spec.pop("tools", []) or []
84+
tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools if tool]
85+
86+
return {
87+
"id": task_id,
88+
"tools": tools,
89+
**spec,
90+
**workflows_dict,
91+
}
92+
93+
94+
def spec_to_task(**spec) -> Task | CreateTaskRequest:
95+
if not spec.get("id"):
96+
spec["id"] = spec.pop("task_id", None)
97+
98+
if not spec.get("updated_at"):
99+
[updated_at_ms, _] = spec.pop("updated_at_ms", None)
100+
spec["updated_at"] = updated_at_ms and (updated_at_ms / 1000)
101+
102+
cls = Task if spec["id"] else CreateTaskRequest
103+
return cls(**spec_to_task_data(spec))

agents-api/agents_api/common/protocol/tasks.py

+14-106
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Annotated, Any, Literal
2-
from uuid import UUID
32

4-
from beartype import beartype
53
from temporalio import workflow
64
from temporalio.exceptions import ApplicationError
75

@@ -10,29 +8,20 @@
108
from pydantic_partial import create_partial_model
119

1210
from ...autogen.openapi_model import (
13-
Agent,
14-
CreateTaskRequest,
1511
CreateToolRequest,
1612
CreateTransitionRequest,
17-
Execution,
1813
ExecutionStatus,
19-
PartialTaskSpecDef,
20-
PatchTaskRequest,
21-
Session,
22-
Task,
23-
TaskSpec,
24-
TaskSpecDef,
25-
TaskToolDef,
2614
Tool,
2715
ToolRef,
2816
TransitionTarget,
2917
TransitionType,
30-
UpdateTaskRequest,
31-
User,
3218
Workflow,
3319
WorkflowStep,
3420
)
3521

22+
from ...queries.executions import list_execution_transitions
23+
from .models import ExecutionInput
24+
3625
# TODO: Maybe we should use a library for this
3726

3827
# State Machine
@@ -145,23 +134,10 @@ class PartialTransition(create_partial_model(CreateTransitionRequest)):
145134
user_state: dict[str, Any] = Field(default_factory=dict)
146135

147136

148-
class ExecutionInput(BaseModel):
149-
developer_id: UUID
150-
execution: Execution | None = None
151-
task: TaskSpecDef | None = None
152-
agent: Agent
153-
agent_tools: list[Tool | CreateToolRequest]
154-
arguments: dict[str, Any]
155-
156-
# Not used at the moment
157-
user: User | None = None
158-
session: Session | None = None
159-
160-
161137
class StepContext(BaseModel):
162138
execution_input: ExecutionInput
163-
inputs: list[Any]
164139
cursor: TransitionTarget
140+
current_input: Any
165141

166142
@computed_field
167143
@property
@@ -197,16 +173,6 @@ def tools(self) -> list[Tool | CreateToolRequest]:
197173

198174
return filtered_tools + task_tools
199175

200-
@computed_field
201-
@property
202-
def outputs(self) -> list[dict[str, Any]]: # included in dump
203-
return self.inputs[1:]
204-
205-
@computed_field
206-
@property
207-
def current_input(self) -> dict[str, Any]: # included in dump
208-
return self.inputs[-1]
209-
210176
@computed_field
211177
@property
212178
def current_workflow(self) -> Annotated[Workflow, Field(exclude=True)]:
@@ -239,13 +205,22 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:
239205

240206
return dump | execution_input
241207

208+
async def get_inputs(self) -> list[Any]:
209+
transitions = await list_execution_transitions(
210+
execution_id=self.execution_input.execution.id,
211+
limit=1000,
212+
direction="asc",
213+
)
214+
return [t.output for t in transitions]
215+
242216
async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
243217
current_input = self.current_input
244-
inputs = self.inputs
218+
inputs = await self.get_inputs()
245219

246220
# Merge execution inputs into the dump dict
247221
dump = self.model_dump(*args, **kwargs)
248222
dump["inputs"] = inputs
223+
dump["outputs"] = inputs[1:]
249224
prepared = dump | {"_": current_input}
250225

251226
for i, input in enumerate(inputs):
@@ -260,70 +235,3 @@ class StepOutcome(BaseModel):
260235
error: str | None = None
261236
output: Any = None
262237
transition_to: tuple[TransitionType, TransitionTarget] | None = None
263-
264-
265-
@beartype
266-
def task_to_spec(
267-
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
268-
) -> TaskSpecDef | PartialTaskSpecDef:
269-
task_data = task.model_dump(
270-
**model_opts, exclude={"version", "developer_id", "task_id", "id", "agent_id"}
271-
)
272-
273-
if "tools" in task_data:
274-
del task_data["tools"]
275-
276-
tools = []
277-
for tool in task.tools:
278-
tool_spec = getattr(tool, tool.type)
279-
280-
tool_obj = dict(
281-
type=tool.type,
282-
spec=tool_spec.model_dump(),
283-
**tool.model_dump(exclude={"type"}),
284-
)
285-
tools.append(TaskToolDef(**tool_obj))
286-
287-
workflows = [Workflow(name="main", steps=task_data.pop("main"))]
288-
289-
for key, steps in list(task_data.items()):
290-
if key not in TaskSpec.model_fields:
291-
workflows.append(Workflow(name=key, steps=steps))
292-
del task_data[key]
293-
294-
cls = PartialTaskSpecDef if isinstance(task, PatchTaskRequest) else TaskSpecDef
295-
296-
return cls(
297-
workflows=workflows,
298-
tools=tools,
299-
**task_data,
300-
)
301-
302-
303-
def spec_to_task_data(spec: dict) -> dict:
304-
task_id = spec.pop("task_id", None)
305-
306-
workflows = spec.pop("workflows")
307-
workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows}
308-
309-
tools = spec.pop("tools", []) or []
310-
tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools if tool]
311-
312-
return {
313-
"id": task_id,
314-
"tools": tools,
315-
**spec,
316-
**workflows_dict,
317-
}
318-
319-
320-
def spec_to_task(**spec) -> Task | CreateTaskRequest:
321-
if not spec.get("id"):
322-
spec["id"] = spec.pop("task_id", None)
323-
324-
if not spec.get("updated_at"):
325-
[updated_at_ms, _] = spec.pop("updated_at_ms", None)
326-
spec["updated_at"] = updated_at_ms and (updated_at_ms / 1000)
327-
328-
cls = Task if spec["id"] else CreateTaskRequest
329-
return cls(**spec_to_task_data(spec))

agents-api/agents_api/queries/executions/prepare_execution_input.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from beartype import beartype
44

5-
from ...common.protocol.tasks import ExecutionInput
5+
from ...common.protocol.models import ExecutionInput
66
from ...common.utils.db_exceptions import common_db_exceptions
77
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
88

agents-api/agents_api/queries/tasks/create_or_update_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from uuid_extensions import uuid7
66

77
from ...autogen.openapi_model import CreateOrUpdateTaskRequest, ResourceUpdatedResponse
8-
from ...common.protocol.tasks import task_to_spec
8+
from ...common.protocol.models import task_to_spec
99
from ...common.utils.db_exceptions import common_db_exceptions
1010
from ...metrics.counters import increase_counter
1111
from ..utils import generate_canonical_name, pg_query, rewrap_exceptions, wrap_in_class

agents-api/agents_api/queries/tasks/create_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from uuid_extensions import uuid7
66

77
from ...autogen.openapi_model import CreateTaskRequest, ResourceCreatedResponse
8-
from ...common.protocol.tasks import task_to_spec
8+
from ...common.protocol.models import task_to_spec
99
from ...common.utils.db_exceptions import common_db_exceptions
1010
from ...metrics.counters import increase_counter
1111
from ..utils import (

agents-api/agents_api/queries/tasks/get_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from beartype import beartype
55

6-
from ...common.protocol.tasks import spec_to_task
6+
from ...common.protocol.models import spec_to_task
77
from ...common.utils.db_exceptions import common_db_exceptions
88
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
99

agents-api/agents_api/queries/tasks/list_tasks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from beartype import beartype
55
from fastapi import HTTPException
66

7-
from ...common.protocol.tasks import spec_to_task
7+
from ...common.protocol.models import spec_to_task
88
from ...common.utils.db_exceptions import common_db_exceptions
99
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
1010

agents-api/agents_api/queries/tasks/patch_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from beartype import beartype
55

66
from ...autogen.openapi_model import PatchTaskRequest, ResourceUpdatedResponse
7-
from ...common.protocol.tasks import task_to_spec
7+
from ...common.protocol.models import task_to_spec
88
from ...common.utils.datetime import utcnow
99
from ...common.utils.db_exceptions import common_db_exceptions
1010
from ...metrics.counters import increase_counter

agents-api/agents_api/queries/tasks/update_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from beartype import beartype
55

66
from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateTaskRequest
7-
from ...common.protocol.tasks import task_to_spec
7+
from ...common.protocol.models import task_to_spec
88
from ...common.utils.datetime import utcnow
99
from ...common.utils.db_exceptions import common_db_exceptions
1010
from ...metrics.counters import increase_counter

agents-api/agents_api/routers/tasks/create_task_execution.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from ...clients.temporal import run_task_execution_workflow
2121
from ...common.protocol.developers import Developer
22-
from ...common.protocol.tasks import task_to_spec
22+
from ...common.protocol.models import task_to_spec
2323
from ...dependencies.developer_id import get_developer_id
2424
from ...env import max_free_executions
2525
from ...queries.developers.get_developer import get_developer

agents-api/agents_api/routers/utils/__init__.py

Whitespace-only changes.

agents-api/agents_api/worker/worker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def create_worker(client: Client) -> Any:
2121

2222
from ..activities import task_steps
2323
from ..activities.demo import demo_activity
24-
from ..activities.excecute_api_call import execute_api_call
24+
from ..activities.execute_api_call import execute_api_call
2525
from ..activities.execute_integration import execute_integration
2626
from ..activities.execute_system import execute_system
2727
from ..activities.sync_items_remote import load_inputs_remote, save_inputs_remote

0 commit comments

Comments
 (0)