Skip to content

Commit d6f36f3

Browse files
committed
Add with_provider_options and use that for opting into caching
1 parent 00e69ae commit d6f36f3

File tree

14 files changed

+271
-45
lines changed

14 files changed

+271
-45
lines changed

docs/_core_features/chat.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -505,21 +505,21 @@ Refer to the [Working with Models Guide]({% link _advanced/models.md %}) for det
505505
### Enabling
506506
For Anthropic models, you can opt-in to prompt caching which is documented more fully in the [Anthropic API docs](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
507507

508-
Enable prompt caching using the `cache_prompts` method on your chat instance:
508+
Enable prompt caching using the `with_provider_options` method on your chat instance:
509509

510510
```ruby
511511
chat = RubyLLM.chat(model: 'claude-3-5-haiku-20241022')
512512

513513
# Enable caching for different types of content
514-
chat.cache_prompts(
515-
system: true, # Cache system instructions
516-
user: true, # Cache user messages
517-
tools: true # Cache tool definitions
514+
chat.with_provider_options(
515+
cache_last_system_prompt: true, # Cache system instructions
516+
cache_last_user_prompt: true, # Cache user messages
517+
cache_tools: true # Cache tool definitions
518518
)
519519
```
520520

521521
### Checking cached token counts
522-
For Anthropic, OpenAI, and Gemini, you can see the number of tokens read from cache by looking at the `cached_tokens` property on the output messages.
522+
For Anthropic, OpenAI, and Gemini, you can see the number of tokens read from cache by looking at the `cached_tokens` property on the output messages.
523523

524524
For Anthropic, you can see the tokens written to cache by looking at the `cache_creation_tokens` property.
525525

lib/ruby_llm/active_record/chat_methods.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def with_schema(...)
135135
self
136136
end
137137

138-
def cache_prompts(...)
139-
to_llm.cache_prompts(...)
138+
def with_provider_options(...)
139+
to_llm.with_provider_options(...)
140140
self
141141
end
142142

lib/ruby_llm/chat.rb

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ def with_schema(schema)
9898
self
9999
end
100100

101+
def with_provider_options(options)
102+
@provider.with_options(options)
103+
self
104+
end
105+
101106
def on_new_message(&block)
102107
@on[:new_message] = block
103108
self
@@ -122,18 +127,12 @@ def each(&)
122127
messages.each(&)
123128
end
124129

125-
def cache_prompts(system: false, user: false, tools: false)
126-
@cache_prompts = { system: system, user: user, tools: tools }
127-
self
128-
end
129-
130130
def complete(&) # rubocop:disable Metrics/PerceivedComplexity
131131
response = @provider.complete(
132132
messages,
133133
tools: @tools,
134134
temperature: @temperature,
135135
model: @model,
136-
cache_prompts: @cache_prompts.dup,
137136
params: @params,
138137
headers: @headers,
139138
schema: @schema,

lib/ruby_llm/provider.rb

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@ module RubyLLM
55
class Provider
66
include Streaming
77

8-
attr_reader :config, :connection
8+
attr_reader :config, :connection, :options
99

1010
def initialize(config)
1111
@config = config
1212
ensure_configured!
1313
@connection = Connection.new(self, @config)
14+
@options = self.class.respond_to?(:options) ? self.class.options.new : nil
15+
end
16+
17+
def with_options(options)
18+
@options = options.is_a?(self.class.options) ? options : self.class.options.new(**options)
19+
self
1420
end
1521

1622
def api_base
@@ -37,8 +43,7 @@ def configuration_requirements
3743
self.class.configuration_requirements
3844
end
3945

40-
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, # rubocop:disable Metrics/ParameterLists
41-
cache_prompts: { system: false, user: false, tools: false }, &)
46+
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
4247
normalized_temperature = maybe_normalize_temperature(temperature, model)
4348

4449
payload = Utils.deep_merge(
@@ -47,7 +52,6 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
4752
tools: tools,
4853
temperature: normalized_temperature,
4954
model: model,
50-
cache_prompts: cache_prompts,
5155
stream: block_given?,
5256
schema: schema
5357
),

lib/ruby_llm/providers/anthropic.rb

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,26 @@ def headers
2222
}
2323
end
2424

25+
# Options for the Anthropic provider
26+
class Options
27+
attr_accessor :cache_last_system_prompt, :cache_last_user_prompt, :cache_tools
28+
29+
def initialize(cache_last_system_prompt: false, cache_last_user_prompt: false, cache_tools: false)
30+
@cache_last_system_prompt = cache_last_system_prompt
31+
@cache_last_user_prompt = cache_last_user_prompt
32+
@cache_tools = cache_tools
33+
end
34+
end
35+
2536
class << self
2637
def capabilities
2738
Anthropic::Capabilities
2839
end
2940

41+
def options
42+
Anthropic::Options
43+
end
44+
3045
def configuration_requirements
3146
%i[anthropic_api_key]
3247
end

lib/ruby_llm/providers/anthropic/chat.rb

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,29 @@ def completion_url
1111
'/v1/messages'
1212
end
1313

14-
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil, # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
15-
cache_prompts: { system: false, user: false, tools: false })
14+
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
1615
system_messages, chat_messages = separate_messages(messages)
17-
system_content = build_system_content(system_messages, cache: cache_prompts[:system])
16+
system_content = build_system_content(system_messages, options)
1817

19-
build_base_payload(chat_messages, model, stream, cache: cache_prompts[:user]).tap do |payload|
20-
add_optional_fields(payload, system_content:, tools:, temperature:,
21-
cache_tools: cache_prompts[:tools])
18+
build_base_payload(chat_messages, model, stream).tap do |payload|
19+
add_optional_fields(payload, system_content:, tools:, temperature:, options:)
2220
end
2321
end
2422

2523
def separate_messages(messages)
2624
messages.partition { |msg| msg.role == :system }
2725
end
2826

29-
def build_system_content(system_messages, cache: false)
27+
def build_system_content(system_messages, options)
3028
system_messages.flat_map.with_index do |msg, idx|
31-
message_cache = cache if idx == system_messages.size - 1
29+
message_cache = options.cache_last_system_prompt if idx == system_messages.size - 1
3230
format_system_message(msg, cache: message_cache)
3331
end
3432
end
3533

36-
def build_base_payload(chat_messages, model, stream, cache: false)
34+
def build_base_payload(chat_messages, model, stream)
3735
messages = chat_messages.map.with_index do |msg, idx|
38-
message_cache = cache if idx == chat_messages.size - 1
36+
message_cache = options.cache_last_user_prompt if idx == chat_messages.size - 1
3937
format_message(msg, cache: message_cache)
4038
end
4139

@@ -47,10 +45,10 @@ def build_base_payload(chat_messages, model, stream, cache: false)
4745
}
4846
end
4947

50-
def add_optional_fields(payload, system_content:, tools:, temperature:, cache_tools: false)
48+
def add_optional_fields(payload, system_content:, tools:, temperature:, options:)
5149
if tools.any?
5250
tool_definitions = tools.values.map { |t| Tools.function_for(t) }
53-
tool_definitions[-1][:cache_control] = { type: 'ephemeral' } if cache_tools
51+
tool_definitions[-1][:cache_control] = { type: 'ephemeral' } if options.cache_tools
5452
payload[:tools] = tool_definitions
5553
end
5654

lib/ruby_llm/providers/bedrock.rb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,15 @@ def build_headers(signature_headers, streaming: false)
6767
'Accept' => accept_header
6868
)
6969
end
70-
7170
class << self
7271
def capabilities
7372
Bedrock::Capabilities
7473
end
7574

75+
def options
76+
Anthropic::Options
77+
end
78+
7679
def configuration_requirements
7780
%i[bedrock_api_key bedrock_secret_key bedrock_region]
7881
end

lib/ruby_llm/providers/bedrock/chat.rb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,19 @@ def completion_url
3939
"model/#{@model_id}/invoke"
4040
end
4141

42-
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil, # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
43-
cache_prompts: { system: false, user: false, tools: false })
42+
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
4443
@model_id = model.id
4544

4645
system_messages, chat_messages = Anthropic::Chat.separate_messages(messages)
47-
system_content = Anthropic::Chat.build_system_content(system_messages, cache: cache_prompts[:system])
46+
system_content = Anthropic::Chat.build_system_content(system_messages, options)
4847

49-
build_base_payload(chat_messages, model, cache: cache_prompts[:user]).tap do |payload|
48+
build_base_payload(chat_messages, model).tap do |payload|
5049
Anthropic::Chat.add_optional_fields(
5150
payload,
5251
system_content:,
5352
tools:,
5453
temperature:,
55-
cache_tools: cache_prompts[:tools]
54+
options:
5655
)
5756
end
5857
end

lib/ruby_llm/providers/gemini/chat.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def completion_url
1111
"models/#{@model}:generateContent"
1212
end
1313

14-
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil, cache_prompts: {}) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
14+
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
1515
@model = model.id
1616
payload = {
1717
contents: format_messages(messages),

lib/ruby_llm/providers/mistral/chat.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def format_role(role)
1212
end
1313

1414
# rubocop:disable Metrics/ParameterLists
15-
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil, cache_prompts: {}) # rubocop:disable Metrics/ParameterLists
15+
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists
1616
payload = super
1717
payload.delete(:stream_options)
1818
payload

0 commit comments

Comments
 (0)