Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/_core_features/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,38 @@ puts response.content
# => "Current weather at 52.52, 13.4: Temperature: 12.5°C, Wind Speed: 8.3 km/h, Conditions: Mainly clear, partly cloudy, and overcast."
```

### Tool Choice Control

Control when and how tools are called using `choice` and `parallel` options.

```ruby
chat = RubyLLM.chat(model: 'gpt-4o')

# Basic usage with defaults
chat.with_tools(Weather, Calculator) # uses provider defaults

# Force tool usage, one at a time
chat.with_tools(Weather, Calculator, choice: :required, parallel: false)

# Force specific tool
chat.with_tool(Weather, choice: :weather, parallel: true)
```

**Parameter Values:**
- **`choice`**: Controls tool choice behavior
- `:auto` Model decides whether to use any tools
- `:required` - Model must use one of the provided tools
- `:none` - Disable all tools
- `"tool_name"` - Force a specific tool (e.g., `:weather` for `Weather` tool)
- **`parallel`**: Controls parallel tool calls
- `true` Allow multiple tool calls simultaneously
- `false` - One at a time

If not provided, RubyLLM will use the provider's default behavior for tool choice and parallel tool calls.

> With `:required` or specific tool choices, the tool_choice is automatically reset to `nil` after tool execution to prevent infinite loops.
{: .note }

### Model Compatibility

RubyLLM will attempt to use tools with any model. If the model doesn't support function calling, the provider will return an appropriate error when you call `ask`.
Expand Down
39 changes: 34 additions & 5 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module RubyLLM
class Chat
include Enumerable

attr_reader :model, :messages, :tools, :params, :headers, :schema
attr_reader :model, :messages, :tools, :tool_prefs, :params, :headers, :schema

def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil)
if assume_model_exists && !provider
Expand All @@ -19,6 +19,7 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
@temperature = nil
@messages = []
@tools = {}
@tool_prefs = { choice: nil, parallel: nil }
@params = {}
@headers = {}
@schema = nil
Expand All @@ -44,15 +45,19 @@ def with_instructions(instructions, replace: false)
self
end

def with_tool(tool)
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
def with_tool(tool, choice: nil, parallel: nil)
unless tool.nil?
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
end
update_tool_options(choice:, parallel:)
self
end

def with_tools(*tools, replace: false)
def with_tools(*tools, replace: false, choice: nil, parallel: nil)
@tools.clear if replace
tools.compact.each { |tool| with_tool tool }
update_tool_options(choice:, parallel:)
self
end

Expand Down Expand Up @@ -125,6 +130,7 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
response = @provider.complete(
messages,
tools: @tools,
tool_prefs: @tool_prefs,
temperature: @temperature,
model: @model,
params: @params,
Expand Down Expand Up @@ -200,6 +206,7 @@ def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
halt_result = result if result.is_a?(Tool::Halt)
end

reset_tool_choice if forced_tool_choice?
halt_result || complete(&)
end

Expand All @@ -208,5 +215,27 @@ def execute_tool(tool_call)
args = tool_call.arguments
tool.call(args)
end

def update_tool_options(choice:, parallel:)
unless choice.nil?
valid_tool_choices = %i[auto none required] + tools.keys
unless valid_tool_choices.include?(choice.to_sym)
raise InvalidToolChoiceError,
"Invalid tool choice: #{choice}. Valid choices are: #{valid_tool_choices.join(', ')}"
end

@tool_prefs[:choice] = choice.to_sym
end

@tool_prefs[:parallel] = !!parallel unless parallel.nil?
end

def forced_tool_choice?
@tool_prefs[:choice] && !%i[auto none].include?(@tool_prefs[:choice])
end

def reset_tool_choice
@tool_prefs[:choice] = nil
end
end
end
1 change: 1 addition & 0 deletions lib/ruby_llm/error.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def initialize(response = nil, message = nil)
# Error classes for non-HTTP errors
class ConfigurationError < StandardError; end
class InvalidRoleError < StandardError; end
class InvalidToolChoiceError < StandardError; end
class ModelNotFoundError < StandardError; end
class UnsupportedAttachmentError < StandardError; end

Expand Down
6 changes: 5 additions & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ def configuration_requirements
self.class.configuration_requirements
end

def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil,
tool_prefs: nil, &)
normalized_temperature = maybe_normalize_temperature(temperature, model)

payload = Utils.deep_merge(
render_payload(
messages,
tools: tools,
tool_prefs: tool_prefs,
temperature: normalized_temperature,
model: model,
stream: block_given?,
Expand All @@ -58,6 +61,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
sync_response @connection, payload, headers
end
end
# rubocop:enable Metrics/ParameterLists

def list_models
response = @connection.get models_url
Expand Down
8 changes: 8 additions & 0 deletions lib/ruby_llm/providers/anthropic/capabilities.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def supports_functions?(model_id)
model_id.match?(/claude-3/)
end

def supports_tool_choice?(_model_id)
true
end

def supports_tool_parallel_control?(_model_id)
true
end

def supports_json_mode?(model_id)
model_id.match?(/claude-3/)
end
Expand Down
15 changes: 11 additions & 4 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ def completion_url
'/v1/messages'
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:,
temperature:, model:, stream: false, schema: nil)
system_messages, chat_messages = separate_messages(messages)
system_content = build_system_content(system_messages)

build_base_payload(chat_messages, model, stream).tap do |payload|
add_optional_fields(payload, system_content:, tools:, temperature:)
add_optional_fields(payload, system_content:, tools:, tool_prefs:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def separate_messages(messages)
messages.partition { |msg| msg.role == :system }
Expand All @@ -44,8 +47,12 @@ def build_base_payload(chat_messages, model, stream)
}
end

def add_optional_fields(payload, system_content:, tools:, temperature:)
payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any?
def add_optional_fields(payload, system_content:, tools:, tool_prefs:, temperature:)
if tools.any?
payload[:tools] = tools.values.map { |t| Tools.function_for(t) }
payload[:tool_choice] = Tools.build_tool_choice(tool_prefs) unless tool_prefs[:choice].nil?
end

payload[:system] = system_content unless system_content.empty?
payload[:temperature] = temperature unless temperature.nil?
end
Expand Down
19 changes: 19 additions & 0 deletions lib/ruby_llm/providers/anthropic/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def clean_parameters(parameters)
def required_parameters(parameters)
parameters.select { |_, param| param.required }.keys
end

def build_tool_choice(tool_prefs)
tool_choice = tool_prefs[:choice]
parallel_tool_calls = tool_prefs[:parallel]

{
type: case tool_choice
when :auto, :none
tool_choice
when :required
:any
else
:tool
end
}.tap do |tc|
tc[:name] = tool_choice if tc[:type] == :tool
tc[:disable_parallel_tool_use] = !parallel_tool_calls unless tc[:type] == :none || parallel_tool_calls.nil?
end
end
end
end
end
Expand Down
8 changes: 8 additions & 0 deletions lib/ruby_llm/providers/bedrock/capabilities.rb
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def supports_functions?(model_id)
model_id.match?(/anthropic\.claude/)
end

def supports_tool_choice?(model_id)
model_id.match?(/anthropic\.claude/)
end

def supports_tool_parallel_control?(_model_id)
false
end

def supports_audio?(_model_id)
false
end
Expand Down
8 changes: 6 additions & 2 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ def completion_url
"model/#{@model_id}/invoke"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false,
schema: nil)
@model_id = model.id

system_messages, chat_messages = Anthropic::Chat.separate_messages(messages)
system_content = Anthropic::Chat.build_system_content(system_messages)

build_base_payload(chat_messages, model).tap do |payload|
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, temperature:)
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, tool_prefs:,
temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def build_base_payload(chat_messages, model)
{
Expand Down
8 changes: 8 additions & 0 deletions lib/ruby_llm/providers/deepseek/capabilities.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def supports_functions?(model_id)
model_id.match?(/deepseek-chat/)
end

def supports_tool_choice?(_model_id)
true
end

def supports_tool_parallel_control?(_model_id)
false
end

def supports_json_mode?(_model_id)
false
end
Expand Down
8 changes: 8 additions & 0 deletions lib/ruby_llm/providers/gemini/capabilities.rb
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def supports_functions?(model_id)
model_id.match?(/gemini|pro|flash/)
end

def supports_tool_choice?(_model_id)
true
end

def supports_tool_parallel_control?(_model_id)
false
end

def supports_json_mode?(model_id)
if model_id.match?(/text-embedding|embedding-001|aqa|imagen|gemini-2\.0-flash-lite|gemini-2\.5-pro-exp-03-25/)
return false
Expand Down
11 changes: 9 additions & 2 deletions lib/ruby_llm/providers/gemini/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def completion_url
"models/#{@model}:generateContent"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false, schema: nil)
@model = model.id
payload = {
contents: format_messages(messages),
Expand All @@ -25,9 +26,15 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:generationConfig][:responseSchema] = convert_schema_to_gemini(schema)
end

payload[:tools] = format_tools(tools) if tools.any?
if tools.any?
payload[:tools] = format_tools(tools)
# Gemini doesn't support controlling parallel tool calls
payload[:toolConfig] = build_tool_config(tool_prefs[:choice]) unless tool_prefs[:choice].nil?
end

payload
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

private

Expand Down
19 changes: 19 additions & 0 deletions lib/ruby_llm/providers/gemini/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ def param_type_for_gemini(type)
else 'STRING'
end
end

def build_tool_config(tool_choice)
{
functionCallingConfig: {
mode: forced_tool_choice?(tool_choice) ? 'any' : tool_choice
}.tap do |config|
# Use allowedFunctionNames to simulate specific tool choice
config[:allowedFunctionNames] = [tool_choice] if specific_tool_choice?(tool_choice)
end
}
end

def forced_tool_choice?(tool_choice)
tool_choice == :required || specific_tool_choice?(tool_choice)
end

def specific_tool_choice?(tool_choice)
!%i[auto none required].include?(tool_choice)
end
end
end
end
Expand Down
4 changes: 4 additions & 0 deletions lib/ruby_llm/providers/gpustack.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def local?
def configuration_requirements
%i[gpustack_api_base]
end

def capabilities
GPUStack::Capabilities
end
end
end
end
Expand Down
20 changes: 20 additions & 0 deletions lib/ruby_llm/providers/gpustack/capabilities.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# frozen_string_literal: true

module RubyLLM
module Providers
class GPUStack
# Determines capabilities for GPUStack models
module Capabilities
module_function

def supports_tool_choice?(_model_id)
false
end

def supports_tool_parallel_control?(_model_id)
false
end
end
end
end
end
Loading