shobrook/saplings

🌳 Saplings

A framework for building agents that use search algorithms to complete tasks.

By incorporating search, an agent can explore different tool-use trajectories and find the optimal path. This ability to look multiple steps ahead reduces errors and boosts overall task performance –– especially on complex reasoning problems, like code generation or navigating a website. With saplings, you can build search into your agents with just a couple lines of code.

  • Supports popular search algorithms: Monte Carlo Tree Search (MCTS), A*, and greedy best-first search
  • Uses OpenAI function calling under the hood
  • Full control over the evaluation function, prompts, search parameters, etc.

Demo

Why add search?

Chain-of-thought and ReAct-style agents are vulnerable to compounding errors: even a small mistake early in the loop can snowball and ruin the final output. Adding tree search gives your agent lookahead and backtracking abilities, making it easier to recover from such mistakes. And as compute becomes cheaper, it will become table stakes for agents to use inference-time search.


Installation

$ pip install saplings

Quickstart

Below is a simple agent implementing Monte Carlo tree search (MCTS). It's equipped with a multiplication tool to solve tricky arithmetic problems.

from saplings.examples import MultiplicationTool
from saplings.llms import OpenAI
from saplings import MonteCarloAgent, Evaluator

model = OpenAI(model="gpt-4o", api_key="YOUR_API_KEY")
evaluator = Evaluator(model)
tools = [MultiplicationTool()]

agent = MonteCarloAgent(tools, model, evaluator)
messages, _, _ = agent.run("Let x = 9418.343 * 8.11 and y = 2x. Calculate (xy)(x^2).")

This is the "bare minimum" for setting up a search agent with saplings –– just a few lines of code. There are a lot more parameters you can control, all covered in the docs. But let's first walk through the basics of creating your own tools and configuring an agent.

Creating a tool

Tools are what your agent will use to perform a task or answer a query. Each tool must extend the Tool base class and implement a few variables and methods. Here's an example of a simple tool that multiplies two numbers together:

from saplings.abstract import Tool

class MultiplicationTool(Tool):
   def __init__(self, **kwargs):
      self.name = "multiply"
      self.description = "Multiplies two numbers and returns the result number."
      self.parameters = {
         "type": "object",
         "properties": {
            "a": {
               "type": "number",
               "description": "The number to multiply."
            },
            "b": {
               "type": "number",
               "description": "The number to multiply by."
            }
         },
         "required": ["a", "b"],
         "additionalProperties": False
      }
      self.is_terminal = False

   async def run(self, a, b, **kwargs):
      return a * b

Variables:

The instance variables in the class tell the agent when and how to call the tool. If you've used OpenAI function calling before, most of this should be familiar to you.

  • name (str): Name of the tool.
  • description (str): Description of what the tool does and when to call it.
  • parameters (dict): Arguments for the tool as a JSON schema.
  • is_terminal (bool): If True, calling this tool will terminate a search trajectory, meaning no subsequent tools can be called after this one. This is typically used for tools that generate a final output for the user (e.g. an answer to a question). See Termination conditions for details.

run() method:

This is what actually executes the tool when the agent calls it. Arguments should be the same as the input parameters in the tool schema.
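A mismatch between the schema and the run() signature is easy to introduce and only shows up when the agent calls the tool. The sketch below shows one way to check for it up front. It is not part of saplings: check_signature is a hypothetical helper, and MultiplyStub is a plain stand-in class with the same attributes as the tool above.

```python
import inspect


class MultiplyStub:
    """Stand-in shaped like the MultiplicationTool above (not a saplings class)."""

    def __init__(self):
        self.parameters = {
            "type": "object",
            "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
            "required": ["a", "b"],
        }

    async def run(self, a, b, **kwargs):
        return a * b


def check_signature(tool):
    """Hypothetical helper: list schema properties that run() does not accept by name."""
    sig = inspect.signature(tool.run)
    named = {
        name
        for name, param in sig.parameters.items()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }
    return sorted(set(tool.parameters["properties"]) - named)


missing = check_signature(MultiplyStub())
print(missing)  # an empty list means every schema property has a matching argument
```

Running a check like this in a unit test catches renamed or forgotten arguments before any LLM call is made.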

Advanced options:

There are additional things you can do with tools, such as accessing the agent's memory during tool execution, or controlling how tool output is shown to the model (vs. how it's stored in memory). These options are covered under Advanced tool options.

Configuring an agent

Choosing a model:

First, you need to choose a model for the agent to use. Saplings only supports OpenAI models right now, but Anthropic and Groq are on the roadmap.

from saplings.llms import OpenAI

model = OpenAI(model="gpt-4o", api_key="YOUR_API_KEY") # Defaults to os.getenv("OPENAI_API_KEY") if empty

Note: if you pass in additional **kwargs, they will get passed to all the chat completion calls made with this model.

Setting up the evaluator:

This is what will guide the search process. The evaluator takes a search trajectory (i.e. a list of OpenAI messages) and returns a score between 0 and 1, indicating how promising the trajectory is. By default, a score of 1.0 means the agent has solved the problem and can terminate the search. You can change the solution cutoff by setting the threshold parameter in the agent, which is described under Parameters in the docs.

from saplings import Evaluator

evaluator = Evaluator(model)

The default evaluator provided by saplings uses an LLM (i.e. the model you pass in above) to score trajectories. The Evaluator object has parameters that let you control things like the system prompt used and the sampling rate. You can also define your own custom evaluator if necessary. Read more about evaluators in the Custom evaluators section.

Choosing an agent/search algorithm:

Once your tools, model, and evaluator are ready, you can simply plug them into a saplings agent. There are multiple to choose from, each implementing its own tree search algorithm: MonteCarloAgent, AStarAgent, and GreedyAgent. There's also a regular chain-of-thought agent available, COTAgent, which does not implement any search. Each agent has its own advantages and disadvantages, which you can read about in the Agents section of the docs.

from saplings import MonteCarloAgent

agent = MonteCarloAgent(tools, model, evaluator)

This will initialize your agent. To actually run it on an input, call the run method. To run it asynchronously, call the run_async method.

messages, score, is_solution = agent.run("What's 2 * 2?") # await agent.run_async("What's 2 * 2?")

The output is a list of messages representing the best tool-use trajectory, the final score of the trajectory (as given by the evaluator), and whether or not the search terminated because the evaluator deemed the trajectory a solution to the prompt. The messages are Message objects, which are special objects native to saplings that wrap OpenAI messages.

Notably, there are many more parameters you can set for the agent, such as the system prompt that governs it.

Docs

Agents

Parameters

Every agent in saplings has the same parameters, listed below:

  1. tools (List[Tool]): List of tools your agent can use.
  2. model (Model): LLM provider that your agent will use to call tools.
  3. evaluator (BaseEvaluator): Evaluation function that the agent will use to guide the search process.
  4. prompt (str): System prompt for the agent.
  5. b_factor (int): Branching factor, i.e. the number of potential next tool calls to evaluate at each step in a search trajectory. Note that this parameter does not do anything for COTAgent.
  6. max_depth (int): Maximum depth of the search tree, indicating how many levels the agent can explore.
  7. threshold (float): A cutoff value for the evaluation function. If a trajectory's evaluation score is at or above this threshold, the search will terminate and that trajectory will be accepted as the solution.
  8. verbose (bool): Whether to print logging statements when you run the agent.
  9. tool_choice ("auto" | "required"): Same as the tool_choice parameter in the OpenAI chat completions function. Indicates whether the model must always call a tool, or if it can decide to generate a normal response instead.
  10. parallel_tool_calls (bool): Same as the parallel_tool_calls parameter in the OpenAI chat completions function. Indicates whether the model can generate multiple tool calls in a single completion request.
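To get a feel for how b_factor and max_depth interact, it helps to compute the size of a fully expanded tree. The sketch below is plain arithmetic, not a saplings function: max_candidates is a hypothetical helper, the values are illustrative rather than library defaults, and a real search will usually stop well before expanding every node.

```python
def max_candidates(b_factor: int, max_depth: int) -> int:
    """Upper bound on nodes below the root when every node has b_factor children."""
    return sum(b_factor ** depth for depth in range(1, max_depth + 1))


small = max_candidates(b_factor=3, max_depth=5)  # 3 + 9 + 27 + 81 + 243
large = max_candidates(b_factor=5, max_depth=5)  # 5 + 25 + 125 + 625 + 3125
print(small, large)  # 363 3905
```

Since each candidate typically costs at least one tool call and one evaluation, raising either parameter can increase the worst-case cost sharply.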

GreedyAgent: Greedy best-first search

This agent implements a greedy best-first search. It's the fastest and cheapest search agent, in terms of LLM calls, but it's also incapable of backtracking, thus making it the least effective agent. GreedyAgent works by taking the input and generating a set of candidate tool calls. It executes each tool call and evaluates their outputs. Then, it picks the best tool call based on its evaluation and generates a set of candidate next tool calls. It repeats this process until a termination condition is met.
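The loop described above can be sketched on a toy problem. This is a minimal illustration of greedy best-first search, not the GreedyAgent implementation: greedy_search, expand, and score are hypothetical names, integers stand in for tool-call trajectories, and a simple function stands in for the evaluator.

```python
from typing import Callable, List


def greedy_search(
    start: int,
    expand: Callable[[int], List[int]],
    score: Callable[[int], float],
    max_depth: int,
    threshold: float = 1.0,
) -> List[int]:
    """Follow the best-scoring candidate at each step; rejected candidates are never revisited."""
    path = [start]
    for _ in range(max_depth):
        candidates = expand(path[-1])
        if not candidates:
            break
        best = max(candidates, key=score)
        path.append(best)
        if score(best) >= threshold:
            break
    return path


# Toy problem: reach 10 from 0 by adding 1, 2, or 3 per step.
target = 10
path = greedy_search(
    start=0,
    expand=lambda n: [n + step for step in (1, 2, 3) if n + step <= target],
    score=lambda n: n / target,
    max_depth=6,
)
print(path)  # [0, 3, 6, 9, 10]
```

Because only the current best candidate is kept, a misleading score at any step commits the search to that branch, which is why this strategy cannot backtrack.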

MonteCarloAgent: Monte Carlo tree search

Demo

This agent implements the Monte Carlo tree search (MCTS) algorithm, based on the paper Language Agent Tree Search (Zhou et al.). It is the most effective agent you can build with saplings, but also the slowest and most expensive (in terms of LLM calls) in the worst case. The primary advantage of this agent is its ability to balance exploration and exploitation, allowing it to efficiently find optimal trajectories by using past experiences and adjusting its strategy accordingly.

Note that, besides the parameters listed above, this agent has one additional parameter:

  1. max_rollouts (int, default = 10): This controls the maximum number of simulations the agent can perform.
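The exploration/exploitation balance in MCTS is commonly implemented with the UCT rule (UCB1 applied to trees). The sketch below shows that standard formula only. Whether MonteCarloAgent uses exactly this rule is not stated here, and uct_score and the exploration constant c are assumptions made for illustration.

```python
import math


def uct_score(total_value: float, visits: int, parent_visits: int, c: float = 1.0) -> float:
    """Mean value of a child plus a bonus that shrinks as the child is visited more."""
    if visits == 0:
        return math.inf  # unvisited children are tried first
    exploit = total_value / visits
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore


# Child A: high average value, visited often. Child B: lower average, visited once.
a = uct_score(total_value=7.2, visits=9, parent_visits=10)
b = uct_score(total_value=0.6, visits=1, parent_visits=10)
print(b > a)  # True: the rarely visited child is selected next
```

Each rollout updates the visit counts and values along the selected path, so later selections reflect what earlier rollouts found.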

AStarAgent: A* search

Demo

Implements a variation of the A* pathfinding algorithm, based on the paper Tree Search for Language Model Agents (Koh et al.). Compared with GreedyAgent, this agent makes more LLM calls in the worst case, but it can backtrack and recover from mistakes. However, unlike MonteCarloAgent, it does not update its search strategy based on the trajectories it has already explored. AStarAgent is often a good middle ground between GreedyAgent (dumb but fast) and MonteCarloAgent (smart but slow).
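Backtracking comes from keeping every unexpanded candidate in a shared frontier rather than only the current best one. The sketch below is a plain best-first search with a priority queue on a toy tree. It is a simplification, not the AStarAgent implementation or the exact algorithm from the paper, and CHILDREN, SCORES, and best_first are hypothetical names.

```python
import heapq
from typing import Dict, List, Tuple

# Toy search tree: node -> children, plus a fixed evaluator score per node.
CHILDREN: Dict[str, List[str]] = {
    "root": ["a", "b"],
    "a": ["a1"],
    "b": ["b1"],
    "a1": [],
    "b1": [],
}
SCORES: Dict[str, float] = {"root": 0.0, "a": 0.7, "b": 0.6, "a1": 0.2, "b1": 1.0}


def best_first(start: str, threshold: float = 1.0) -> Tuple[List[str], List[str]]:
    """Expand the highest-scoring node anywhere in the frontier until one meets the threshold."""
    frontier = [(-SCORES[start], [start])]  # max-heap via negated scores
    expanded: List[str] = []
    best = (SCORES[start], [start])
    while frontier:
        neg_score, path = heapq.heappop(frontier)
        score = -neg_score
        expanded.append(path[-1])
        if score > best[0]:
            best = (score, path)
        if score >= threshold:
            return path, expanded
        for child in CHILDREN[path[-1]]:
            heapq.heappush(frontier, (-SCORES[child], path + [child]))
    return best[1], expanded


path, expanded = best_first("root")
print(path)      # ['root', 'b', 'b1']
print(expanded)  # ['root', 'a', 'b', 'b1']
```

The search first follows "a" because it scores higher, sees that "a1" is weak, and returns to "b" from the frontier. A greedy search would have stayed on the "a" branch.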

COTAgent: Chain-of-thought (no search)

This is a standard tool-calling agent and does not implement any search. It takes an input, calls a tool, then uses the tool output to inform the next tool call, and so on until a termination condition is met. Think of COTAgent as a baseline to compare your search agents to.

The Message object

Messages are a core data structure in saplings. They are essentially equivalent to OpenAI messages (e.g. user input, tool calls, tool responses, assistant responses), with a few extra properties and helper methods. A list of messages represents a search trajectory. When you run an agent, it will return a list of messages representing the best trajectory it found.

Saplings messages can be easily converted into OpenAI messages using the to_openai_message() method.

messages, _, _ = agent.run("This is my prompt!")
messages = [message.to_openai_message() for message in messages]

print(messages)
# [{"role": "user", "content": "This is my prompt!"}, ..., {"role": "assistant", "content": "This is a response!"}]

Message objects have only one additional attribute that OpenAI messages don't have. If a message represents a tool response, it will have a raw_output property that contains the output of that tool. What's stored here may be different than the tool response that gets shown to the model, which is stored in the content property.

Termination conditions

Every tool has an is_terminal property. This is a boolean flag that tells the agent if calling the tool should terminate a search trajectory. If it's True, no subsequent tool calls can be made after the tool is invoked, and the agent will terminate that search trajectory. Terminal tools are typically used to generate some sort of final output for the user (e.g. an answer to a question).

We say that an agent can self-terminate if it has at least one terminal tool, OR if the tool_choice parameter is set to "auto." In the latter case, this means that calling a tool is optional for the agent, and instead of a tool call, it can generate a regular assistant response to the input prompt. We consider such a response to also terminate a search trajectory.

If an agent cannot self-terminate, then a search trajectory will only ever terminate if either a maximum depth is reached (set by the max_depth parameter), or the evaluator marks a trajectory as solved (i.e. the score is >= the agent's threshold parameter) –– in which case the entire search itself terminates.

One common point of confusion: even if an evaluator marks a trajectory as solved, the search may not terminate if the agent can self-terminate. This happens when a trajectory ends with a non-terminal tool call (or a non-assistant response, in the case when tool use is optional) but is still given a score at or above the solution threshold. In this case, the search will continue until it reaches a terminal state that is also marked as solved. If no terminal state is ever reached, the trajectory with the best score is returned. If no solution is ever found, and there is one trajectory with a terminal state and another with a non-terminal state but a higher score, the terminal trajectory is preferred and returned.
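The return rules above for a self-terminating agent can be written as a small selection function. This is a sketch of the rules as described, not code from saplings. Candidate and pick_result are hypothetical names, and each explored trajectory is reduced to a score and a terminal flag.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    """Hypothetical summary of one explored trajectory."""
    name: str
    score: float
    is_terminal: bool


def pick_result(candidates: List[Candidate], threshold: float = 1.0) -> Candidate:
    """Prefer a solved terminal trajectory, then any terminal one, then the best score."""
    terminal = [c for c in candidates if c.is_terminal]
    solved = [c for c in terminal if c.score >= threshold]
    if solved:
        return max(solved, key=lambda c: c.score)
    if terminal:
        return max(terminal, key=lambda c: c.score)
    return max(candidates, key=lambda c: c.score)


explored = [
    Candidate("answered", score=0.6, is_terminal=True),
    Candidate("still working", score=0.9, is_terminal=False),
]
print(pick_result(explored).name)  # answered
```

The terminal trajectory wins despite its lower score, which matches the preference described in the last sentence above.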

Advanced tool options

Accessing agent memory

In some cases, running your tool may depend on the output of the previous tools your agent has used, or the user input itself. If this is the case, you can access the agent's current search trajectory in the run method when you implement your tool. Simply use kwargs.get("trajectory"). This will return a list of Message objects, which are wrappers around OpenAI messages.
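As a sketch of this pattern, the tool below reads the most recent numeric tool output from the trajectory and adds to it. It runs without saplings installed: FakeMessage is a stand-in for Message, RunningTotalTool is a hypothetical tool that does not extend the real Tool base class, and the role attribute is assumed to mirror the role field of an OpenAI message.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FakeMessage:
    """Stand-in for a saplings Message; only the fields used below."""
    role: str
    content: str
    raw_output: Any = None


class RunningTotalTool:
    """Hypothetical tool that adds a number to the last numeric tool output."""

    async def run(self, amount, **kwargs):
        trajectory: List[FakeMessage] = kwargs.get("trajectory") or []
        previous = [
            m.raw_output
            for m in trajectory
            if m.role == "tool" and isinstance(m.raw_output, (int, float))
        ]
        base = previous[-1] if previous else 0
        return base + amount


history = [
    FakeMessage(role="user", content="Add 5 to 2 * 3."),
    FakeMessage(role="tool", content="2 * 3 = 6", raw_output=6),
]
total = asyncio.run(RunningTotalTool().run(5, trajectory=history))
print(total)  # 11
```

Falling back to an empty list when no trajectory is passed keeps the tool usable in isolation, for example in unit tests.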

The format_output() method

In some cases, it makes sense for the raw output of a tool to be separated from the output that's shown to the model. By default, the output of run() is what's shown to the model. But you can add the optional format_output method to your tool class to change how the output is presented to the agent. For example, in our quickstart example, instead of seeing the multiplication result N, you might want the model to see "A * B = N" so the agent can more easily keep track of what numbers have been multiplied. Here's how you'd modify the tool to do that:

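class MultiplicationTool(Tool):
   ...

   async def run(self, a, b, **kwargs):
      # Return the operands along with the result so format_output can use them
      return {"a": a, "b": b, "result": a * b}

   def format_output(self, output):
      # Build the string that is shown to the model
      a, b = output['a'], output['b']
      result = output['result']
      return f"{a} * {b} = {result}"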

The unformatted output of the tool is still stored in the agent's memory. It can be accessed via the raw_output property of the Message object that represents the tool response.

Custom evaluators

Every agent implements a heuristic search algorithm, meaning that it uses some heuristic or value function to guide the search. By default, saplings offers the Evaluator object, which evaluates a search trajectory using an LLM. It takes a trajectory (i.e. a list of OpenAI messages) as input and returns a score between 0 and 1 which tells the agent if it's on the right track or not, along with some written justification for the score.

The Evaluator object has the following parameters:

  1. model (Model): The LLM used to generate the score.
  2. n_samples (int): The number of scores to generate for a given trajectory. Equivalent to the n parameter in an OpenAI chat completion. If it's greater than 1, multiple candidate scores will be generated for a given trajectory and then averaged to return the final score. Making this greater than 1 is equivalent to enabling self-consistency in the evaluation process.
  3. prompt (str): The system prompt that tells the model how it should evaluate a trajectory and generate a score.
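The averaging behind n_samples has a practical consequence when combined with the agent's threshold. The sketch below uses made-up sample scores, and aggregate_scores is a hypothetical helper, not the Evaluator's internal code.

```python
from statistics import mean
from typing import List


def aggregate_scores(samples: List[float]) -> float:
    """Average the sampled scores into one final score."""
    return mean(samples)


final = aggregate_scores([0.5, 0.75, 1.0])
solved = final >= 1.0  # compared against a threshold of 1.0
print(final, solved)  # 0.75 False
```

With a threshold of 1.0, a single sample below 1.0 keeps the averaged score under the cutoff, so you may want to lower the threshold when using more than one sample.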

In most cases, simply customizing this object will be sufficient, but in some situations it makes sense to build your own evaluator. For example, if you're building a coding agent, you may want to evaluate a search trajectory using some external feedback, such as whether the code compiles or whether a set of unit tests are passing. To build a custom evaluator, you must extend the Evaluator base class and implement a run method. This method must take in a list of Message objects as input, representing a search trajectory, and return an EvaluationDTO object as output. This object has two properties: score (a value between 0 and 1) and reasoning (an optional string with written justification for the score).

from typing import List

from saplings.abstract import Evaluator
from saplings.dtos import EvaluationDTO, Message  # assumes Message is exported from saplings.dtos

class CustomEvaluator(Evaluator):
   def __init__(self):
      pass

   async def run(self, trajectory: List[Message]) -> EvaluationDTO:
      # Implement this
      return EvaluationDTO(score=1.0, reasoning="Justification goes here.")

Note that the trajectory will always contain the original input message, every tool call, and every tool response. For the tool responses, you can access the raw output of the tool using the Message.raw_output property, described in The Message object section.

Each agent has a threshold parameter, which determines the minimum score at which to terminate the search and deem a trajectory as a solution. By default, it is 1.0, so you should keep this in mind when designing your evaluator.
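As an illustration of the external-feedback idea mentioned above, here is a sketch of an evaluator that scores generated code by the fraction of checks it passes. It runs without saplings installed, so it uses stand-ins: EvaluationStub mirrors the two documented properties of EvaluationDTO, UnitTestEvaluator does not extend the real base class, and the trajectory is assumed to be a list whose last item is the generated source as a string.

```python
import asyncio
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class EvaluationStub:
    """Stand-in for EvaluationDTO with the two documented properties."""
    score: float
    reasoning: Optional[str] = None


class UnitTestEvaluator:
    """Hypothetical evaluator: score is the fraction of checks that pass."""

    def __init__(self, checks: List[Callable[[str], bool]]):
        self.checks = checks

    async def run(self, trajectory) -> EvaluationStub:
        source = trajectory[-1]  # assumption: the last item holds the generated code
        passed = sum(1 for check in self.checks if check(source))
        total = len(self.checks)
        return EvaluationStub(score=passed / total, reasoning=f"{passed}/{total} checks passed")


def compiles(source: str) -> bool:
    try:
        compile(source, "<candidate>", "exec")
        return True
    except SyntaxError:
        return False


def defines_add(source: str) -> bool:
    return "def add(" in source


evaluator = UnitTestEvaluator([compiles, defines_add])
good = asyncio.run(evaluator.run(["def add(a, b):\n    return a + b\n"]))
bad = asyncio.run(evaluator.run(["def add(a, b) return a + b"]))
print(good.score, bad.score)  # 1.0 0.5
```

With the default threshold of 1.0, only a trajectory that passes every check is accepted as a solution. Partial scores still help the search rank candidates.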

Roadmap

  1. Support for chat history
  2. Support for Anthropic and Groq models
  3. Allow dynamic system prompts and tool schemas (i.e. prompts that change as the agent progresses)
  4. Support for vision agents
  5. Add an llm_call_budget parameter to every agent

Mission: More inference-time compute makes agents smarter. And as models get cheaper and faster, search will become more viable to use in production. Let's build the easiest and most powerful framework for building search-enabled agents.

Note from the author

One of my other open-source packages used to be called saplings. It has since been renamed to syntaxis and is now associated with the package of the same name on PyPI.
