Skip to content

Commit de88a33

Browse files
committed
fix(agents-api): Fix session.create & Add session.update system tools
1 parent f8514a2 commit de88a33

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

Diff for: agents-api/agents_api/activities/execute_system.py

+39-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
HybridDocSearchRequest,
1717
SystemDef,
1818
TextOnlyDocSearchRequest,
19+
UpdateSessionRequest,
1920
VectorDocSearchRequest,
2021
)
21-
from ..common.protocol.tasks import StepContext
22+
from ..common.protocol.tasks import ExecutionInput, StepContext
2223
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
2324
from ..env import testing
2425
from ..models.developer import get_developer
@@ -40,6 +41,10 @@ async def execute_system(
4041
if set(arguments.keys()) == {"bucket", "key"}:
4142
arguments = await load_from_blob_store_if_remote(arguments)
4243

44+
if not isinstance(context.execution_input, ExecutionInput):
45+
raise TypeError(
46+
"Expected ExecutionInput type for context.execution_input")
47+
4348
arguments["developer_id"] = context.execution_input.developer_id
4449

4550
# Unbox all the arguments
@@ -91,7 +96,8 @@ async def execute_system(
9196

9297
# Handle chat operations
9398
if system.operation == "chat" and system.resource == "session":
94-
developer = get_developer(developer_id=arguments.get("developer_id"))
99+
developer = get_developer(
100+
developer_id=arguments.get("developer_id"))
95101
session_id = arguments.get("session_id")
96102
x_custom_api_key = arguments.get("x_custom_api_key", None)
97103
chat_input = ChatInput(**arguments)
@@ -106,10 +112,11 @@ async def execute_system(
106112
await bg_runner()
107113
return res
108114

115+
# Handle create operations
109116
if system.operation == "create" and system.resource == "session":
110117
developer_id = arguments.pop("developer_id")
111118
session_id = arguments.pop("session_id", None)
112-
data = CreateSessionRequest(**arguments)
119+
create_session_request = CreateSessionRequest(**arguments)
113120

114121
# In case sessions.create becomes asynchronous in the future
115122
if asyncio.iscoroutinefunction(handler):
@@ -118,7 +125,35 @@ async def execute_system(
118125
# Run the synchronous function in another process
119126
loop = asyncio.get_running_loop()
120127
return await loop.run_in_executor(
121-
process_pool_executor, partial(handler, developer_id, session_id, data)
128+
process_pool_executor,
129+
partial(
130+
handler,
131+
developer_id=developer_id,
132+
session_id=session_id,
133+
data=create_session_request,
134+
),
135+
)
136+
137+
# Handle update operations
138+
if system.operation == "update" and system.resource == "session":
139+
developer_id = arguments.pop("developer_id")
140+
session_id = arguments.pop("session_id")
141+
update_session_request = UpdateSessionRequest(**arguments)
142+
143+
# In case sessions.update becomes asynchronous in the future
144+
if asyncio.iscoroutinefunction(handler):
145+
return await handler()
146+
147+
# Run the synchronous function in another process
148+
loop = asyncio.get_running_loop()
149+
return await loop.run_in_executor(
150+
process_pool_executor,
151+
partial(
152+
handler,
153+
developer_id=developer_id,
154+
session_id=session_id,
155+
data=update_session_request,
156+
),
122157
)
123158

124159
# Handle regular operations

0 commit comments

Comments
 (0)