Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rfc: callback changes #1165

Merged
merged 4 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,9 @@ def _take_next_step(
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
self.callback_manager.on_agent_action(
output, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool]
Expand All @@ -415,15 +418,15 @@ def _take_next_step(
llm_prefix = "" if return_direct else self.agent.llm_prefix
# We then call the tool on the tool input to get an observation
observation = tool.run(
output,
output.tool_input,
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
observation_prefix=self.agent.observation_prefix,
)
else:
observation = InvalidTool().run(
output,
output.tool_input,
verbose=self.verbose,
color=None,
llm_prefix="",
Expand Down Expand Up @@ -451,6 +454,9 @@ async def _atake_next_step(
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
self.callback_manager.on_agent_action(
output, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
if output.tool in name_to_tool_map:
tool = name_to_tool_map[output.tool]
Expand All @@ -459,15 +465,15 @@ async def _atake_next_step(
llm_prefix = "" if return_direct else self.agent.llm_prefix
# We then call the tool on the tool input to get an observation
observation = await tool.arun(
output,
output.tool_input,
verbose=self.verbose,
color=color,
llm_prefix=llm_prefix,
observation_prefix=self.agent.observation_prefix,
)
else:
observation = await InvalidTool().arun(
output,
output.tool_input,
verbose=self.verbose,
color=None,
llm_prefix="",
Expand Down
47 changes: 40 additions & 7 deletions langchain/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def on_chain_error(

@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uber nit: tool_input seems like a better name?

) -> Any:
"""Run when tool starts running."""

Expand All @@ -86,6 +86,10 @@ def on_tool_error(
def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""

@abstractmethod
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""

@abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
Expand Down Expand Up @@ -203,15 +207,24 @@ def on_chain_error(
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_start(serialized, action, **kwargs)
handler.on_tool_start(serialized, input_str, **kwargs)

def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_action(action, **kwargs)

def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when tool ends running."""
Expand Down Expand Up @@ -293,7 +306,7 @@ async def on_chain_error(
"""Run when chain errors."""

async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""

Expand All @@ -308,6 +321,9 @@ async def on_tool_error(
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""

async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""

async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""

Expand Down Expand Up @@ -452,7 +468,7 @@ async def on_chain_error(
async def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
Expand All @@ -461,12 +477,12 @@ async def on_tool_start(
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_start):
await handler.on_tool_start(serialized, action, **kwargs)
await handler.on_tool_start(serialized, input_str, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_tool_start, serialized, action, **kwargs
handler.on_tool_start, serialized, input_str, **kwargs
),
)

Expand Down Expand Up @@ -514,6 +530,23 @@ async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None
None, functools.partial(handler.on_text, text, **kwargs)
)

async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent action."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_agent_action):
await handler.on_agent_action(action, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_action, action, **kwargs
),
)

async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
Expand Down
7 changes: 5 additions & 2 deletions langchain/callbacks/openai_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def on_chain_error(
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
Expand Down Expand Up @@ -92,6 +91,10 @@ def on_text(
"""Run when agent ends."""
pass

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass

def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
Expand Down
9 changes: 7 additions & 2 deletions langchain/callbacks/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,16 @@ def on_chain_error(
self._callback_manager.on_chain_error(error, **kwargs)

def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(serialized, action, **kwargs)
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
with self._lock:
self._callback_manager.on_agent_action(action, **kwargs)

def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
Expand Down
11 changes: 8 additions & 3 deletions langchain/callbacks/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,16 @@ def on_chain_error(
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
"""Do nothing."""
pass

def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color if color else self.color)

def on_tool_end(
Expand Down
6 changes: 5 additions & 1 deletion langchain/callbacks/streaming_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ def on_chain_error(
"""Run when chain errors."""

def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass

def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""

Expand Down
6 changes: 5 additions & 1 deletion langchain/callbacks/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ def on_chain_error(
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
pass

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
# st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))

Expand Down
11 changes: 8 additions & 3 deletions langchain/callbacks/tracers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def on_chain_error(
self._end_trace()

def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Start a trace for a tool run."""
if self._session is None:
Expand All @@ -209,8 +209,9 @@ def on_tool_start(

tool_run = ToolRun(
serialized=serialized,
action=action.tool,
tool_input=action.tool_input,
# TODO: this is duplicate info as above, not needed.
action=str(serialized),
Copy link
Collaborator

@agola11 agola11 Feb 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can rm this from the tracing backend and make the changes here next time we release a new tracing version

tool_input=input_str,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
Expand Down Expand Up @@ -250,6 +251,10 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Handle an agent finish message."""
pass

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass


class Tracer(BaseTracer, ABC):
"""A non-thread safe implementation of the BaseTracer interface."""
Expand Down
18 changes: 8 additions & 10 deletions langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import AgentAction


class BaseTool(BaseModel):
Expand Down Expand Up @@ -45,12 +44,11 @@ async def _arun(self, tool_input: str) -> str:

def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
agent_action = AgentAction(tool_input=tool_input, tool=self.name, log="")
return self.run(agent_action)
return self.run(tool_input)

def run(
self,
action: AgentAction,
tool_input: str,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
Expand All @@ -61,13 +59,13 @@ def run(
verbose = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
action,
tool_input,
verbose=verbose,
color=start_color,
**kwargs,
)
try:
observation = self._run(action.tool_input)
observation = self._run(tool_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose)
raise e
Expand All @@ -78,7 +76,7 @@ def run(

async def arun(
self,
action: AgentAction,
tool_input: str,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
Expand All @@ -90,22 +88,22 @@ async def arun(
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
action,
tool_input,
verbose=verbose,
color=start_color,
**kwargs,
)
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
action,
tool_input,
verbose=verbose,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(action.tool_input)
observation = await self._arun(tool_input)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose)
Expand Down
12 changes: 8 additions & 4 deletions tests/unit_tests/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
assert handler.tool_starts == 2
assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
Expand Down Expand Up @@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
assert handler.tool_starts == 2
assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
Expand Down
Loading