1
1
from typing import Annotated , Any , Literal
2
- from uuid import UUID
3
2
4
- from beartype import beartype
5
3
from temporalio import workflow
6
4
from temporalio .exceptions import ApplicationError
7
5
10
8
from pydantic_partial import create_partial_model
11
9
12
10
from ...autogen .openapi_model import (
13
- Agent ,
14
- CreateTaskRequest ,
15
11
CreateToolRequest ,
16
12
CreateTransitionRequest ,
17
- Execution ,
18
13
ExecutionStatus ,
19
- PartialTaskSpecDef ,
20
- PatchTaskRequest ,
21
- Session ,
22
- Task ,
23
- TaskSpec ,
24
- TaskSpecDef ,
25
- TaskToolDef ,
26
14
Tool ,
27
15
ToolRef ,
28
16
TransitionTarget ,
29
17
TransitionType ,
30
- UpdateTaskRequest ,
31
- User ,
32
18
Workflow ,
33
19
WorkflowStep ,
34
20
)
35
21
22
+ from ...queries .executions import list_execution_transitions
23
+ from .models import ExecutionInput
24
+
36
25
# TODO: Maybe we should use a library for this
37
26
38
27
# State Machine
@@ -145,23 +134,10 @@ class PartialTransition(create_partial_model(CreateTransitionRequest)):
145
134
user_state : dict [str , Any ] = Field (default_factory = dict )
146
135
147
136
148
- class ExecutionInput (BaseModel ):
149
- developer_id : UUID
150
- execution : Execution | None = None
151
- task : TaskSpecDef | None = None
152
- agent : Agent
153
- agent_tools : list [Tool | CreateToolRequest ]
154
- arguments : dict [str , Any ]
155
-
156
- # Not used at the moment
157
- user : User | None = None
158
- session : Session | None = None
159
-
160
-
161
137
class StepContext (BaseModel ):
162
138
execution_input : ExecutionInput
163
- inputs : list [Any ]
164
139
cursor : TransitionTarget
140
+ current_input : Any
165
141
166
142
@computed_field
167
143
@property
@@ -197,16 +173,6 @@ def tools(self) -> list[Tool | CreateToolRequest]:
197
173
198
174
return filtered_tools + task_tools
199
175
200
- @computed_field
201
- @property
202
- def outputs (self ) -> list [dict [str , Any ]]: # included in dump
203
- return self .inputs [1 :]
204
-
205
- @computed_field
206
- @property
207
- def current_input (self ) -> dict [str , Any ]: # included in dump
208
- return self .inputs [- 1 ]
209
-
210
176
@computed_field
211
177
@property
212
178
def current_workflow (self ) -> Annotated [Workflow , Field (exclude = True )]:
@@ -239,13 +205,22 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:
239
205
240
206
return dump | execution_input
241
207
208
+ async def get_inputs (self ) -> list [Any ]:
209
+ transitions = await list_execution_transitions (
210
+ execution_id = self .execution_input .execution .id ,
211
+ limit = 1000 ,
212
+ direction = "asc" ,
213
+ )
214
+ return [t .output for t in transitions ]
215
+
242
216
async def prepare_for_step (self , * args , ** kwargs ) -> dict [str , Any ]:
243
217
current_input = self .current_input
244
- inputs = self .inputs
218
+ inputs = await self .get_inputs ()
245
219
246
220
# Merge execution inputs into the dump dict
247
221
dump = self .model_dump (* args , ** kwargs )
248
222
dump ["inputs" ] = inputs
223
+ dump ["outputs" ] = inputs [1 :]
249
224
prepared = dump | {"_" : current_input }
250
225
251
226
for i , input in enumerate (inputs ):
@@ -260,70 +235,3 @@ class StepOutcome(BaseModel):
260
235
error : str | None = None
261
236
output : Any = None
262
237
transition_to : tuple [TransitionType , TransitionTarget ] | None = None
263
-
264
-
265
- @beartype
266
- def task_to_spec (
267
- task : Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest , ** model_opts
268
- ) -> TaskSpecDef | PartialTaskSpecDef :
269
- task_data = task .model_dump (
270
- ** model_opts , exclude = {"version" , "developer_id" , "task_id" , "id" , "agent_id" }
271
- )
272
-
273
- if "tools" in task_data :
274
- del task_data ["tools" ]
275
-
276
- tools = []
277
- for tool in task .tools :
278
- tool_spec = getattr (tool , tool .type )
279
-
280
- tool_obj = dict (
281
- type = tool .type ,
282
- spec = tool_spec .model_dump (),
283
- ** tool .model_dump (exclude = {"type" }),
284
- )
285
- tools .append (TaskToolDef (** tool_obj ))
286
-
287
- workflows = [Workflow (name = "main" , steps = task_data .pop ("main" ))]
288
-
289
- for key , steps in list (task_data .items ()):
290
- if key not in TaskSpec .model_fields :
291
- workflows .append (Workflow (name = key , steps = steps ))
292
- del task_data [key ]
293
-
294
- cls = PartialTaskSpecDef if isinstance (task , PatchTaskRequest ) else TaskSpecDef
295
-
296
- return cls (
297
- workflows = workflows ,
298
- tools = tools ,
299
- ** task_data ,
300
- )
301
-
302
-
303
- def spec_to_task_data (spec : dict ) -> dict :
304
- task_id = spec .pop ("task_id" , None )
305
-
306
- workflows = spec .pop ("workflows" )
307
- workflows_dict = {workflow ["name" ]: workflow ["steps" ] for workflow in workflows }
308
-
309
- tools = spec .pop ("tools" , []) or []
310
- tools = [{tool ["type" ]: tool .pop ("spec" ), ** tool } for tool in tools if tool ]
311
-
312
- return {
313
- "id" : task_id ,
314
- "tools" : tools ,
315
- ** spec ,
316
- ** workflows_dict ,
317
- }
318
-
319
-
320
- def spec_to_task (** spec ) -> Task | CreateTaskRequest :
321
- if not spec .get ("id" ):
322
- spec ["id" ] = spec .pop ("task_id" , None )
323
-
324
- if not spec .get ("updated_at" ):
325
- [updated_at_ms , _ ] = spec .pop ("updated_at_ms" , None )
326
- spec ["updated_at" ] = updated_at_ms and (updated_at_ms / 1000 )
327
-
328
- cls = Task if spec ["id" ] else CreateTaskRequest
329
- return cls (** spec_to_task_data (spec ))
0 commit comments