From 4f4a07f7911a1b1b59a24b83612a1fca87d3c263 Mon Sep 17 00:00:00 2001 From: Haseong Kim Date: Thu, 8 Aug 2024 05:21:38 +0100 Subject: [PATCH] refactor: Allowed RunnableExecutor to stream output and changed its build method to asynchronous. --- .../components/prototypes/RunnableExecutor.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/backend/base/langflow/components/prototypes/RunnableExecutor.py b/src/backend/base/langflow/components/prototypes/RunnableExecutor.py index 8aec1886f95..0e872080afb 100644 --- a/src/backend/base/langflow/components/prototypes/RunnableExecutor.py +++ b/src/backend/base/langflow/components/prototypes/RunnableExecutor.py @@ -1,7 +1,8 @@ from langflow.custom import Component -from langflow.inputs import HandleInput, MessageTextInput +from langflow.inputs import HandleInput, MessageTextInput, BoolInput from langflow.schema.message import Message from langflow.template import Output +from langchain.agents import AgentExecutor class RunnableExecComponent(Component): @@ -30,6 +31,11 @@ class RunnableExecComponent(Component): value="output", advanced=True, ), + BoolInput( + name="use_stream", + display_name="Stream", + value=False, + ), ] outputs = [ @@ -108,11 +114,24 @@ def get_input_dict(self, runnable, input_key, input_value): status = f"Warning: The input key is not '{input_key}'. The input key is '{runnable.input_keys}'." return input_dict, status - def build_executor(self) -> Message: + async def build_executor(self) -> Message: input_dict, status = self.get_input_dict(self.runnable, self.input_key, self.input_value) - result = self.runnable.invoke(input_dict) + if not isinstance(self.runnable, AgentExecutor): + raise ValueError("The runnable must be an AgentExecutor") + + if self.use_stream: + return self.astream_events(input_dict) + else: + result = await self.runnable.ainvoke(input_dict) result_value, _status = self.get_output(result, self.input_key, self.output_key) status += _status status += f"\n\nOutput: {result_value}\n\nRaw Output: {result}" self.status = status return result_value + + async def astream_events(self, input): + async for event in self.runnable.astream_events(input, version="v1"): + if event.get("event") != "on_chat_model_stream": + continue + + yield event.get("data").get("chunk")