Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,21 @@ async def atrace_as_chain_group(
await run_manager.on_chain_end({})


Func = TypeVar("Func", bound=Callable)


def shielded(func: Func) -> Func:
"""
Makes so an awaitable method is always shielded from cancellation
"""

@functools.wraps(func)
async def wrapped(*args: Any, **kwargs: Any) -> Any:
return await asyncio.shield(func(*args, **kwargs))

return cast(Func, wrapped)


def handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
Expand Down Expand Up @@ -293,7 +308,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
with asyncio.Runner() as runner:
# Run the coroutine, get the result
for coro in coros:
runner.run(coro)
try:
runner.run(coro)
except Exception as e:
logger.warning(f"Error in callback coroutine: {repr(e)}")

# Run pending tasks scheduled by coros until they are all done
while pending := asyncio.all_tasks(runner.get_loop()):
Expand All @@ -302,7 +320,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
# Before Python 3.11 we need to run each coroutine in a new event loop
# as the Runner api is not available.
for coro in coros:
asyncio.run(coro)
try:
asyncio.run(coro)
except Exception as e:
logger.warning(f"Error in callback coroutine: {repr(e)}")


async def _ahandle_event_for_handler(
Expand Down Expand Up @@ -682,6 +703,7 @@ def get_sync(self) -> CallbackManagerForLLMRun:
inheritable_metadata=self.inheritable_metadata,
)

@shielded
async def on_llm_new_token(
self,
token: str,
Expand All @@ -706,6 +728,7 @@ async def on_llm_new_token(
**kwargs,
)

@shielded
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running.

Expand All @@ -723,6 +746,7 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
**kwargs,
)

@shielded
async def on_llm_error(
self,
error: BaseException,
Expand Down Expand Up @@ -853,6 +877,7 @@ def get_sync(self) -> CallbackManagerForChainRun:
inheritable_metadata=self.inheritable_metadata,
)

@shielded
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
Expand All @@ -872,6 +897,7 @@ async def on_chain_end(
**kwargs,
)

@shielded
async def on_chain_error(
self,
error: BaseException,
Expand All @@ -893,6 +919,7 @@ async def on_chain_error(
**kwargs,
)

@shielded
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received.

Expand All @@ -913,6 +940,7 @@ async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
**kwargs,
)

@shielded
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received.

Expand Down Expand Up @@ -1000,6 +1028,7 @@ def get_sync(self) -> CallbackManagerForToolRun:
inheritable_metadata=self.inheritable_metadata,
)

@shielded
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.

Expand All @@ -1017,6 +1046,7 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> None:
**kwargs,
)

@shielded
async def on_tool_error(
self,
error: BaseException,
Expand Down Expand Up @@ -1100,6 +1130,7 @@ def get_sync(self) -> CallbackManagerForRetrieverRun:
inheritable_metadata=self.inheritable_metadata,
)

@shielded
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
Expand All @@ -1115,6 +1146,7 @@ async def on_retriever_end(
**kwargs,
)

@shielded
async def on_retriever_error(
self,
error: BaseException,
Expand Down