Skip to content

Commit 184e058

Browse files
committed
Add lifecycle hooks
1 parent fb9cb1d commit 184e058

File tree

4 files changed

+205
-8
lines changed

4 files changed

+205
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ All notable changes to Parlant will be documented here.
99
- Add ContextEvaluation in MessageEventGenerator
1010
- Fix ToolCaller false-negative argument validation from int to float
1111
- Fix ToolCaller accuracy
12+
- Add engine lifecycle hooks
1213

1314

1415
## [1.5.1] - 2025-01-05

src/parlant/bin/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import uvicorn
3232

3333
from parlant.adapters.vector_db.chroma import ChromaDatabase
34+
from parlant.core.engines.alpha import hooks
3435
from parlant.core.engines.alpha import guideline_proposer
3536
from parlant.core.engines.alpha import tool_caller
3637
from parlant.core.engines.alpha import message_event_generator
@@ -369,6 +370,8 @@ async def setup_container(nlp_service_name: str) -> AsyncIterator[Container]:
369370
c[ShotCollection[ToolCallerInferenceShot]] = tool_caller.shot_collection
370371
c[ShotCollection[MessageEventGeneratorShot]] = message_event_generator.shot_collection
371372

373+
c[hooks.LifecycleHooks] = hooks.lifecycle_hooks
374+
372375
c[GuidelineProposer] = GuidelineProposer(
373376
c[Logger],
374377
c[SchematicGenerator[GuidelinePropositionsSchema]],

src/parlant/core/engines/alpha/engine.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
Term as StoredTerm,
5151
ToolEventData,
5252
)
53+
from parlant.core.engines.alpha.hooks import lifecycle_hooks
5354
from parlant.core.engines.alpha.guideline_proposer import (
5455
GuidelineProposer,
5556
GuidelinePropositionResult,
@@ -126,14 +127,16 @@ async def process(
126127

127128
self._logger.error(f"Processing error: {formatted_exception}")
128129

129-
await event_emitter.emit_status_event(
130-
correlation_id=self._correlator.correlation_id,
131-
data={
132-
"status": "error",
133-
"acknowledged_offset": interaction_state.last_known_event_offset,
134-
"data": {"exception": formatted_exception},
135-
},
136-
)
130+
if await lifecycle_hooks.call_on_error(context, event_emitter, exc):
131+
await event_emitter.emit_status_event(
132+
correlation_id=self._correlator.correlation_id,
133+
data={
134+
"status": "error",
135+
"acknowledged_offset": interaction_state.last_known_event_offset,
136+
"data": {"exception": formatted_exception},
137+
},
138+
)
139+
137140
return False
138141
except BaseException as exc:
139142
self._logger.critical(f"Critical processing error: {traceback.format_exception(exc)}")
@@ -200,6 +203,9 @@ async def _do_process(
200203
session = await self._session_store.read_session(context.session_id)
201204
customer = await self._customer_store.read_customer(session.customer_id)
202205

206+
if not await lifecycle_hooks.call_on_acknowledging(context, event_emitter):
207+
return
208+
203209
await event_emitter.emit_status_event(
204210
correlation_id=self._correlator.correlation_id,
205211
data={
@@ -209,7 +215,13 @@ async def _do_process(
209215
},
210216
)
211217

218+
if not await lifecycle_hooks.call_on_acknowledged(context, event_emitter):
219+
return
220+
212221
try:
222+
if not await lifecycle_hooks.call_on_preparing(context, event_emitter):
223+
return
224+
213225
context_variables = await self._load_context_variables(
214226
agent_id=context.agent_id, session=session, customer=customer
215227
)
@@ -236,6 +248,11 @@ async def _do_process(
236248
prepared_to_respond = False
237249

238250
while not prepared_to_respond:
251+
if not await lifecycle_hooks.call_on_preparation_iteration_start(
252+
context, event_emitter, all_tool_events
253+
):
254+
break
255+
239256
all_possible_guidelines = await self._guideline_store.list_guidelines(
240257
guideline_set=agent.id,
241258
)
@@ -383,8 +400,28 @@ async def _do_process(
383400
},
384401
)
385402

403+
if not await lifecycle_hooks.call_on_preparation_iteration_end(
404+
context,
405+
event_emitter,
406+
all_tool_events,
407+
[gp.guideline for gp in ordinary_guideline_propositions]
408+
+ [gp.guideline for gp in tool_enabled_guideline_propositions.keys()],
409+
):
410+
break
411+
386412
message_generation_inspections = []
387413

414+
if not await lifecycle_hooks.call_on_generating_messages(
415+
context,
416+
event_emitter,
417+
all_tool_events,
418+
[gp.guideline for gp in ordinary_guideline_propositions]
419+
+ [gp.guideline for gp in tool_enabled_guideline_propositions.keys()],
420+
):
421+
return
422+
423+
all_emitted_events = [*all_tool_events]
424+
388425
for event_generation_result in await self._message_event_generator.generate_events(
389426
event_emitter=event_emitter,
390427
agents=[agent],
@@ -408,13 +445,23 @@ async def _do_process(
408445
)
409446
)
410447

448+
all_emitted_events += [e for e in event_generation_result.events if e]
449+
411450
await self._session_store.create_inspection(
412451
session_id=context.session_id,
413452
correlation_id=self._correlator.correlation_id,
414453
preparation_iterations=preparation_iterations,
415454
message_generations=message_generation_inspections,
416455
)
417456

457+
await lifecycle_hooks.call_on_generated_messages(
458+
context,
459+
event_emitter,
460+
all_emitted_events,
461+
[gp.guideline for gp in ordinary_guideline_propositions]
462+
+ [gp.guideline for gp in tool_enabled_guideline_propositions.keys()],
463+
)
464+
418465
except asyncio.CancelledError:
419466
await event_emitter.emit_status_event(
420467
correlation_id=self._correlator.correlation_id,
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from dataclasses import dataclass, field
2+
from enum import Enum, auto
3+
from typing import Any, Awaitable, Callable, Sequence
4+
5+
from parlant.core.emissions import EmittedEvent, EventEmitter
6+
from parlant.core.engines.types import Context
7+
from parlant.core.guidelines import Guideline
8+
9+
10+
class LifecycleHookResult(Enum):
11+
CALL_NEXT = auto()
12+
"""Runs the next hook in the chain, if any"""
13+
14+
RESOLVE = auto()
15+
"""Returns without running the next hooks in the chain"""
16+
17+
BAIL = auto()
18+
"""Returns without running the next hooks in the chain, and quietly discards the current execution.
19+
20+
For most hooks, this completely bails out of the processing execution, dropping the response to the customer.
21+
Specifically for preparation iterations, this immediately signals that preparation is complete.
22+
"""
23+
24+
25+
@dataclass(frozen=False)
26+
class LifecycleHooks:
27+
on_error: list[Callable[[Context, EventEmitter, Exception], Awaitable[LifecycleHookResult]]] = (
28+
field(default_factory=list)
29+
)
30+
"""Called when the engine has encountered a runtime error"""
31+
32+
on_acknowledging: list[Callable[[Context, EventEmitter], Awaitable[LifecycleHookResult]]] = (
33+
field(default_factory=list)
34+
)
35+
"""Called just before emitting an acknowledgement status event"""
36+
37+
on_acknowledged: list[Callable[[Context, EventEmitter], Awaitable[LifecycleHookResult]]] = (
38+
field(default_factory=list)
39+
)
40+
"""Called right after emitting an acknowledgement status event"""
41+
42+
on_preparing: list[Callable[[Context, EventEmitter], Awaitable[LifecycleHookResult]]] = field(
43+
default_factory=list
44+
)
45+
"""Called just before beginning the preparation iterations"""
46+
47+
on_preparation_iteration_start: list[
48+
Callable[
49+
[Context, EventEmitter, list[EmittedEvent]],
50+
Awaitable[LifecycleHookResult],
51+
]
52+
] = field(default_factory=list)
53+
"""Called just before beginning a preparation iteration"""
54+
55+
on_preparation_iteration_end: list[
56+
Callable[
57+
[Context, EventEmitter, list[EmittedEvent], Sequence[Guideline]],
58+
Awaitable[LifecycleHookResult],
59+
]
60+
] = field(default_factory=list)
61+
"""Called right after finishing a preparation iteration"""
62+
63+
on_generating_messages: list[
64+
Callable[
65+
[Context, EventEmitter, list[EmittedEvent], Sequence[Guideline]],
66+
Awaitable[LifecycleHookResult],
67+
]
68+
] = field(default_factory=list)
69+
"""Called just before generating messages"""
70+
71+
on_generated_messages: list[
72+
Callable[
73+
[Context, EventEmitter, Sequence[EmittedEvent], Sequence[Guideline]],
74+
Awaitable[LifecycleHookResult],
75+
]
76+
] = field(default_factory=list)
77+
"""Called right after generating messages"""
78+
79+
async def call_on_error(
80+
self, context: Context, emitter: EventEmitter, exception: Exception
81+
) -> bool:
82+
return await self._call_hook(self.on_error, context, emitter, exception)
83+
84+
async def call_on_acknowledging(self, context: Context, emitter: EventEmitter) -> bool:
85+
return await self._call_hook(self.on_acknowledging, context, emitter)
86+
87+
async def call_on_acknowledged(self, context: Context, emitter: EventEmitter) -> bool:
88+
return await self._call_hook(self.on_acknowledged, context, emitter)
89+
90+
async def call_on_preparing(self, context: Context, emitter: EventEmitter) -> bool:
91+
return await self._call_hook(self.on_preparing, context, emitter)
92+
93+
async def call_on_preparation_iteration_start(
94+
self,
95+
context: Context,
96+
emitter: EventEmitter,
97+
events: list[EmittedEvent],
98+
) -> bool:
99+
return await self._call_hook(self.on_preparation_iteration_start, context, emitter, events)
100+
101+
async def call_on_preparation_iteration_end(
102+
self,
103+
context: Context,
104+
emitter: EventEmitter,
105+
events: list[EmittedEvent],
106+
guidelines: Sequence[Guideline],
107+
) -> bool:
108+
return await self._call_hook(
109+
self.on_preparation_iteration_end, context, emitter, events, guidelines
110+
)
111+
112+
async def call_on_generating_messages(
113+
self,
114+
context: Context,
115+
emitter: EventEmitter,
116+
events: list[EmittedEvent],
117+
guidelines: Sequence[Guideline],
118+
) -> bool:
119+
return await self._call_hook(
120+
self.on_generating_messages, context, emitter, events, guidelines
121+
)
122+
123+
async def call_on_generated_messages(
124+
self,
125+
context: Context,
126+
emitter: EventEmitter,
127+
events: Sequence[EmittedEvent],
128+
guidelines: Sequence[Guideline],
129+
) -> bool:
130+
return await self._call_hook(
131+
self.on_generated_messages, context, emitter, events, guidelines
132+
)
133+
134+
async def _call_hook(self, hook: Any, *args: Any) -> bool:
135+
for callable in hook:
136+
match await callable(*args):
137+
case LifecycleHookResult.CALL_NEXT:
138+
continue
139+
case LifecycleHookResult.RESOLVE:
140+
return True
141+
case LifecycleHookResult.BAIL:
142+
return False
143+
return True
144+
145+
146+
lifecycle_hooks = LifecycleHooks()

0 commit comments

Comments
 (0)