@@ -98,30 +98,41 @@ class RunState(Enum):
98
98
def __init__ (self , runtime : SingleThreadedAgentRuntime ) -> None :
99
99
self ._runtime = runtime
100
100
self ._run_state = RunContext .RunState .RUNNING
101
+ self ._end_condition : Callable [[], bool ] = self ._stop_when_cancelled
101
102
self ._run_task = asyncio .create_task (self ._run ())
102
103
self ._lock = asyncio .Lock ()
103
104
104
105
async def _run (self ) -> None :
105
106
while True :
106
107
async with self ._lock :
107
- if self ._run_state == RunContext . RunState . CANCELLED :
108
+ if self ._end_condition () :
108
109
return
109
- elif self ._run_state == RunContext .RunState .UNTIL_IDLE :
110
- if self ._runtime .idle :
111
- return
112
110
113
111
await self ._runtime .process_next ()
114
112
115
113
async def stop (self ) -> None :
116
114
async with self ._lock :
117
115
self ._run_state = RunContext .RunState .CANCELLED
116
+ self ._end_condition = self ._stop_when_cancelled
118
117
await self ._run_task
119
118
120
119
async def stop_when_idle (self ) -> None :
121
120
async with self ._lock :
122
121
self ._run_state = RunContext .RunState .UNTIL_IDLE
122
+ self ._end_condition = self ._stop_when_idle
123
123
await self ._run_task
124
124
125
+ async def stop_when (self , condition : Callable [[], bool ]) -> None :
126
+ async with self ._lock :
127
+ self ._end_condition = condition
128
+ await self ._run_task
129
+
130
+ def _stop_when_cancelled (self ) -> bool :
131
+ return self ._run_state == RunContext .RunState .CANCELLED
132
+
133
+ def _stop_when_idle (self ) -> bool :
134
+ return self ._run_state == RunContext .RunState .UNTIL_IDLE and self ._runtime .idle
135
+
125
136
126
137
class SingleThreadedAgentRuntime (AgentRuntime ):
127
138
def __init__ (self , * , intervention_handlers : List [InterventionHandler ] | None = None ) -> None :
@@ -449,6 +460,13 @@ async def stop_when_idle(self) -> None:
449
460
await self ._run_context .stop_when_idle ()
450
461
self ._run_context = None
451
462
463
+ async def stop_when (self , condition : Callable [[], bool ]) -> None :
464
+ """Stop the runtime message processing loop when the condition is met."""
465
+ if self ._run_context is None :
466
+ raise RuntimeError ("Runtime is not started" )
467
+ await self ._run_context .stop_when (condition )
468
+ self ._run_context = None
469
+
452
470
async def agent_metadata (self , agent : AgentId ) -> AgentMetadata :
453
471
return (await self ._get_agent (agent )).metadata
454
472
0 commit comments