refactor: update inputs in ToolCallingAgentComponent and add astream_events setup #4240

Merged · merged 10 commits · Oct 29, 2024
100 changes: 91 additions & 9 deletions src/backend/base/langflow/base/agents/agent.py
@@ -1,6 +1,8 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, cast
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, cast

from fastapi.encoders import jsonable_encoder
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
from langchain.agents.agent import RunnableAgent
from langchain_core.runnables import Runnable
@@ -12,6 +14,7 @@
from langflow.inputs.inputs import InputTypes
from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput
from langflow.schema import Data
from langflow.schema.log import LogFunctionType
from langflow.schema.message import Message
from langflow.template import Output
from langflow.utils.constants import MESSAGE_SENDER_AI
@@ -109,7 +112,26 @@ async def run_agent(self, agent: AgentExecutor) -> Text:
msg = "Output key not found in result. Tried 'output'."
raise ValueError(msg)

return cast(str, result.get("output"))
return cast(str, result)

async def handle_chain_start(self, event: dict[str, Any]) -> None:
if event["name"] == "Agent":
self.log(f"Starting agent: {event['name']} with input: {event['data'].get('input')}")

async def handle_chain_end(self, event: dict[str, Any]) -> None:
if event["name"] == "Agent":
self.log(f"Done agent: {event['name']} with output: {event['data'].get('output', {}).get('output', '')}")

async def handle_tool_start(self, event: dict[str, Any]) -> None:
self.log(f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}")

async def handle_tool_end(self, event: dict[str, Any]) -> None:
self.log(f"Done tool: {event['name']}")
self.log(f"Tool output was: {event['data'].get('output')}")

@abstractmethod
def create_agent_runnable(self) -> Runnable:
"""Create the agent."""


class LCToolsAgentComponent(LCAgentComponent):
@@ -146,16 +168,76 @@ async def run_agent(
if self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)

result = runnable.invoke(
input_dict, config={"callbacks": [AgentAsyncHandler(self.log), *self.get_langchain_callbacks()]}
result = await process_agent_events(
runnable.astream_events(
input_dict,
config={"callbacks": [AgentAsyncHandler(self.log), *self.get_langchain_callbacks()]},
version="v2",
),
self.log,
)
self.status = result
if "output" not in result:
msg = "Output key not found in result. Tried 'output'."
raise ValueError(msg)

return cast(str, result.get("output"))
self.status = result
return cast(str, result)

@abstractmethod
def create_agent_runnable(self) -> Runnable:
"""Create the agent."""


# Add this function near the top of the file, after the imports


async def process_agent_events(agent_executor: AsyncIterator[dict[str, Any]], log_callback: LogFunctionType) -> str:
"""Process agent events and return the final output.

Args:
agent_executor: An async iterator of agent events
log_callback: A callable function for logging messages

Returns:
str: The final output from the agent
"""
final_output = ""
async for event in agent_executor:
match event["event"]:
case "on_chain_start":
if event["data"].get("input"):
log_callback(f"Agent initiated with input: {event['data'].get('input')}", name="🚀 Agent Start")

case "on_chain_end":
data_output = event["data"].get("output", {})
if data_output and "output" in data_output:
final_output = data_output["output"]
log_callback(f"{final_output}", name="✅ Agent End")
elif data_output and "agent_scratchpad" in data_output and data_output["agent_scratchpad"]:
agent_scratchpad_messages = data_output["agent_scratchpad"]
json_encoded_messages = jsonable_encoder(agent_scratchpad_messages)
log_callback(json_encoded_messages, name="🔍 Agent Scratchpad")

case "on_tool_start":
log_callback(
f"Initiating tool: '{event['name']}' with inputs: {event['data'].get('input')}",
name="🔧 Tool Start",
)

case "on_tool_end":
log_callback(f"Tool '{event['name']}' execution completed", name="🏁 Tool End")
log_callback(f"{event['data'].get('output')}", name="📊 Tool Output")

case "on_tool_error":
tool_name = event.get("name", "Unknown tool")
error_message = event["data"].get("error", "Unknown error")
log_callback(f"Tool '{tool_name}' failed with error: {error_message}", name="❌ Tool Error")

if "stack_trace" in event["data"]:
log_callback(f"{event['data']['stack_trace']}", name="🔍 Tool Error")

if "recovery_attempt" in event["data"]:
log_callback(f"{event['data']['recovery_attempt']}", name="🔄 Tool Error")

case _:
# Handle any other event types or ignore them
pass

return final_output
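
For reviewers: the sketch below shows roughly how the new process_agent_events helper consumes an astream_events-style stream. It is a minimal illustration, not part of the diff — the fake event generator and the print_log callback are hypothetical stand-ins for runnable.astream_events(..., version="v2") and the component's self.log, and the import assumes the module path matches the file above.

import asyncio
from collections.abc import AsyncIterator
from typing import Any

from langflow.base.agents.agent import process_agent_events  # path assumed from the diff above


async def fake_event_stream() -> AsyncIterator[dict[str, Any]]:
    # Hypothetical events mimicking the astream_events v2 shapes handled above.
    yield {"event": "on_chain_start", "name": "Agent", "data": {"input": {"input": "What is 2 + 2?"}}}
    yield {"event": "on_tool_start", "name": "calculator", "data": {"input": {"expression": "2 + 2"}}}
    yield {"event": "on_tool_end", "name": "calculator", "data": {"output": "4"}}
    yield {"event": "on_chain_end", "name": "Agent", "data": {"output": {"output": "2 + 2 = 4"}}}


def print_log(message: Any, *, name: str = "") -> None:
    # Stand-in for the component's log function, which also accepts a `name` keyword.
    print(f"[{name}] {message}")


async def main() -> None:
    # Each event is logged as it arrives; the final "output" value is returned.
    final = await process_agent_events(fake_event_stream(), print_log)
    print("final output:", final)


asyncio.run(main())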
29 changes: 14 additions & 15 deletions src/backend/base/langflow/components/agents/tool_calling.py
@@ -1,50 +1,49 @@
from langchain.agents import create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate
from langchain_core.prompts import ChatPromptTemplate

from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.inputs import MultilineInput
from langflow.inputs import MessageTextInput
from langflow.inputs.inputs import DataInput, HandleInput
from langflow.schema import Data


class ToolCallingAgentComponent(LCToolsAgentComponent):
display_name: str = "Tool Calling Agent"
description: str = "Agent that uses tools"
description: str = "An agent designed to utilize various tools seamlessly within workflows."
icon = "LangChain"
beta = True
name = "ToolCallingAgent"

inputs = [
*LCToolsAgentComponent._base_inputs,
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
MultilineInput(
MessageTextInput(
name="system_prompt",
display_name="System Prompt",
info="System prompt for the agent.",
value="You are a helpful assistant",
info="Initial instructions and context provided to guide the agent's behavior.",
value="You are a helpful assistant that can use tools to answer questions and perform tasks.",
),
MultilineInput(
name="user_prompt", display_name="Prompt", info="This prompt must contain 'input' key.", value="{input}"
MessageTextInput(
name="input_value",
display_name="Input",
info="The input provided by the user for the agent to process.",
),
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
DataInput(name="chat_history", display_name="Chat Memory", is_list=True, advanced=True),
]

def get_chat_history_data(self) -> list[Data] | None:
return self.chat_history

def create_agent_runnable(self):
if "input" not in self.user_prompt:
msg = "Prompt must contain 'input' key."
raise ValueError(msg)
messages = [
("system", self.system_prompt),
("placeholder", "{chat_history}"),
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)),
("human", self.input_value),
("placeholder", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
try:
return create_tool_calling_agent(self.llm, self.tools, prompt)
return create_tool_calling_agent(self.llm, self.tools or [], prompt)
except NotImplementedError as e:
message = f"{self.display_name} does not support tool calling." "Please try using a compatible model."
message = f"{self.display_name} does not support tool calling. Please try using a compatible model."
raise NotImplementedError(message) from e
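
For context, the prompt layout the refactored component now builds corresponds roughly to the plain-LangChain sketch below. This is an illustrative approximation, not code from this PR: the model name and the add tool are hypothetical, and where the sketch uses an "{input}" variable the component inlines input_value directly as the human message.

from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # assumes langchain-openai is installed


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


llm = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model choice
tools = [add]

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant that can use tools to answer questions and perform tasks."),
        ("placeholder", "{chat_history}"),   # filled from chat_history when provided
        ("human", "{input}"),                # the component inlines input_value here instead
        ("placeholder", "{agent_scratchpad}"),
    ]
)

agent = create_tool_calling_agent(llm, tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
result = executor.invoke({"input": "What is 2 + 3?", "chat_history": []})
print(result["output"])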