diff --git a/src/backend/base/langflow/base/agents/agent.py b/src/backend/base/langflow/base/agents/agent.py index f092caf0aad9..2d7d6a8d64cf 100644 --- a/src/backend/base/langflow/base/agents/agent.py +++ b/src/backend/base/langflow/base/agents/agent.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Union, cast +from typing import List, Optional, Union, cast from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent from langchain.agents.agent import RunnableAgent @@ -10,8 +10,9 @@ from langflow.base.agents.utils import data_to_messages from langflow.custom import Component from langflow.field_typing import Text -from langflow.inputs.inputs import DataInput, InputTypes +from langflow.inputs.inputs import InputTypes from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput +from langflow.schema import Data from langflow.schema.message import Message from langflow.template import Output from langflow.utils.constants import MESSAGE_SENDER_AI @@ -39,7 +40,6 @@ class LCAgentComponent(Component): value=15, advanced=True, ), - DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True), ] outputs = [ @@ -89,8 +89,13 @@ def get_agent_kwargs(self, flatten: bool = False) -> dict: } return {**base, "agent_executor_kwargs": agent_kwargs} + def get_chat_history_data(self) -> Optional[List[Data]]: + # might be overridden in subclasses + return None + async def run_agent(self, agent: AgentExecutor) -> Text: input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value} + self.chat_history = self.get_chat_history_data() if self.chat_history: input_dict["chat_history"] = data_to_messages(self.chat_history) result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]}) diff --git a/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py b/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py index 0676821da2bf..0864463813fd 100644 --- a/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py +++ b/src/backend/base/langflow/components/agents/OpenAIToolsAgent.py @@ -1,9 +1,12 @@ +from typing import Optional, List + from langchain.agents import create_openai_tools_agent from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate from langflow.base.agents.agent import LCToolsAgentComponent from langflow.inputs import MultilineInput -from langflow.inputs.inputs import HandleInput +from langflow.inputs.inputs import HandleInput, DataInput +from langflow.schema import Data class OpenAIToolsAgentComponent(LCToolsAgentComponent): @@ -29,13 +32,18 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent): MultilineInput( name="user_prompt", display_name="Prompt", info="This prompt must contain 'input' key.", value="{input}" ), + DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True), ] + def get_chat_history_data(self) -> Optional[List[Data]]: + return self.chat_history + def create_agent_runnable(self): if "input" not in self.user_prompt: raise ValueError("Prompt must contain 'input' key.") messages = [ ("system", self.system_prompt), + ("placeholder", "{chat_history}"), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)), ("placeholder", "{agent_scratchpad}"), ] diff --git a/src/backend/base/langflow/components/agents/ToolCallingAgent.py b/src/backend/base/langflow/components/agents/ToolCallingAgent.py index aa1c4dcd7fcf..8ab91d84cab0 100644 --- a/src/backend/base/langflow/components/agents/ToolCallingAgent.py +++ b/src/backend/base/langflow/components/agents/ToolCallingAgent.py @@ -1,8 +1,11 @@ +from typing import Optional, List + from langchain.agents import create_tool_calling_agent from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate from langflow.base.agents.agent import LCToolsAgentComponent from langflow.inputs import MultilineInput -from langflow.inputs.inputs import HandleInput +from langflow.inputs.inputs import HandleInput, DataInput +from langflow.schema import Data class ToolCallingAgentComponent(LCToolsAgentComponent): @@ -23,13 +26,18 @@ class ToolCallingAgentComponent(LCToolsAgentComponent): MultilineInput( name="user_prompt", display_name="Prompt", info="This prompt must contain 'input' key.", value="{input}" ), + DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True), ] + def get_chat_history_data(self) -> Optional[List[Data]]: + return self.chat_history + def create_agent_runnable(self): if "input" not in self.user_prompt: raise ValueError("Prompt must contain 'input' key.") messages = [ ("system", self.system_prompt), + ("placeholder", "{chat_history}"), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)), ("placeholder", "{agent_scratchpad}"), ]