16
16
HybridDocSearchRequest ,
17
17
SystemDef ,
18
18
TextOnlyDocSearchRequest ,
19
+ UpdateSessionRequest ,
19
20
VectorDocSearchRequest ,
20
21
)
21
- from ..common .protocol .tasks import StepContext
22
+ from ..common .protocol .tasks import ExecutionInput , StepContext
22
23
from ..common .storage_handler import auto_blob_store , load_from_blob_store_if_remote
23
24
from ..env import testing
24
25
from ..models .developer import get_developer
@@ -40,6 +41,9 @@ async def execute_system(
40
41
if set (arguments .keys ()) == {"bucket" , "key" }:
41
42
arguments = await load_from_blob_store_if_remote (arguments )
42
43
44
+ if not isinstance (context .execution_input , ExecutionInput ):
45
+ raise TypeError ("Expected ExecutionInput type for context.execution_input" )
46
+
43
47
arguments ["developer_id" ] = context .execution_input .developer_id
44
48
45
49
# Unbox all the arguments
@@ -106,10 +110,11 @@ async def execute_system(
106
110
await bg_runner ()
107
111
return res
108
112
113
+ # Handle create operations
109
114
if system .operation == "create" and system .resource == "session" :
110
115
developer_id = arguments .pop ("developer_id" )
111
116
session_id = arguments .pop ("session_id" , None )
112
- data = CreateSessionRequest (** arguments )
117
+ create_session_request = CreateSessionRequest (** arguments )
113
118
114
119
# In case sessions.create becomes asynchronous in the future
115
120
if asyncio .iscoroutinefunction (handler ):
@@ -118,7 +123,35 @@ async def execute_system(
118
123
# Run the synchronous function in another process
119
124
loop = asyncio .get_running_loop ()
120
125
return await loop .run_in_executor (
121
- process_pool_executor , partial (handler , developer_id , session_id , data )
126
+ process_pool_executor ,
127
+ partial (
128
+ handler ,
129
+ developer_id = developer_id ,
130
+ session_id = session_id ,
131
+ data = create_session_request ,
132
+ ),
133
+ )
134
+
135
+ # Handle update operations
136
+ if system .operation == "update" and system .resource == "session" :
137
+ developer_id = arguments .pop ("developer_id" )
138
+ session_id = arguments .pop ("session_id" )
139
+ update_session_request = UpdateSessionRequest (** arguments )
140
+
141
+ # In case sessions.update becomes asynchronous in the future
142
+ if asyncio .iscoroutinefunction (handler ):
143
+ return await handler ()
144
+
145
+ # Run the synchronous function in another process
146
+ loop = asyncio .get_running_loop ()
147
+ return await loop .run_in_executor (
148
+ process_pool_executor ,
149
+ partial (
150
+ handler ,
151
+ developer_id = developer_id ,
152
+ session_id = session_id ,
153
+ data = update_session_request ,
154
+ ),
122
155
)
123
156
124
157
# Handle regular operations
0 commit comments