Skip to content

Commit ca52dfa

Browse files
committed
Handle Gemini maxOutputTokens attribute properly
1 parent a0efaa4 commit ca52dfa

File tree

5 files changed

+181
-1
lines changed

5 files changed

+181
-1
lines changed

lib/ruby_llm/provider.rb

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def headers
2121
{}
2222
end
2323

24+
# Provider-specific mapping of canonical request param keys to nested
# payload paths, consumed by #apply_parameter_mappings. The base
# implementation declares no mappings; subclasses (e.g. the Gemini
# provider) override this to translate params their API spells differently.
def parameter_mappings
  {}
end
27+
2428
# Instance-level convenience for the provider's identifying slug,
# delegating to the class-level definition.
def slug
  self.class.slug
end
@@ -39,6 +43,7 @@ def configuration_requirements
3943

4044
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
4145
normalized_temperature = maybe_normalize_temperature(temperature, model)
46+
transformed_params = apply_parameter_mappings(params)
4247

4348
payload = Utils.deep_merge(
4449
render_payload(
@@ -49,7 +54,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
4954
stream: block_given?,
5055
schema: schema
5156
),
52-
params
57+
transformed_params
5358
)
5459

5560
if block_given?
@@ -192,6 +197,28 @@ def configured_remote_providers(config)
192197

193198
private
194199

200+
# Rewrites caller-supplied params using this provider's #parameter_mappings,
# relocating each mapped key to its nested target path in the payload.
#
# e.g. with { max_tokens: %i[generationConfig maxOutputTokens] },
#      { max_tokens: 100 } becomes { generationConfig: { maxOutputTokens: 100 } }.
#
# Returns the params unchanged when the provider declares no mappings.
# Never mutates the caller's hash: nested hashes along a target path are
# duplicated before being written into (a plain `params.dup` is shallow,
# so writing through it would otherwise modify the caller's nested hash).
def apply_parameter_mappings(params)
  return params if parameter_mappings.empty?

  transformed = params.dup

  parameter_mappings.each do |source_key, target_path|
    next unless transformed.key?(source_key)

    value = transformed.delete(source_key)
    *keys, last_key = target_path

    # Walk (and lazily create) the nested hashes on the target path,
    # duplicating existing ones so the input's nested hashes stay untouched.
    target = keys.inject(transformed) do |hash, key|
      hash[key] = hash[key].is_a?(Hash) ? hash[key].dup : {}
      hash[key]
    end

    target[last_key] = value
  end

  transformed
end
221+
195222
def try_parse_json(maybe_json)
196223
return maybe_json unless maybe_json.is_a?(String)
197224

lib/ruby_llm/providers/gemini.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def headers
2222
}
2323
end
2424

25+
private

# Gemini's API expects the output-token cap as generationConfig.maxOutputTokens
# rather than a top-level max_tokens, so declare that remapping for
# Provider#apply_parameter_mappings to apply to outgoing payloads.
def parameter_mappings
  {
    max_tokens: %i[generationConfig maxOutputTokens]
  }
end
32+
2533
class << self
2634
def capabilities
2735
Gemini::Capabilities

spec/fixtures/vcr_cassettes/chat_with_params_gemini_gemini-2_5-flash_automatically_maps_max_tokens_to_maxoutputtokens.yml

Lines changed: 81 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

spec/ruby_llm/chat_request_options_spec.rb

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,24 @@
106106
expect(json_response).to eq({ 'result' => 8 })
107107
end
108108
end
109+
110+
# Gemini rejects a top-level max_tokens param; RubyLLM must transparently
# remap it to generationConfig.maxOutputTokens in the outgoing request.
CHAT_MODELS.select { |info| info[:provider] == :gemini }.each do |info|
  model = info[:model]
  provider = info[:provider]

  it "#{provider}/#{model} automatically maps max_tokens to maxOutputTokens" do
    chat = RubyLLM
           .chat(model: model, provider: provider)
           .with_params(max_tokens: 100)

    response = chat.ask('Say hello in 3 words.')

    # Inspect the raw request body to confirm the param was relocated,
    # not merely duplicated alongside the unsupported key.
    request_body = JSON.parse(response.raw.env.request_body)
    expect(request_body.dig('generationConfig', 'maxOutputTokens')).to eq(100)
    expect(request_body).not_to have_key('max_tokens')

    expect(response.content).to be_present
  end
end
109128
end
110129
end

spec/ruby_llm/providers/gemini/chat_spec.rb

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,49 @@
233233
# Verify our implementation correctly sums both token types
234234
expect(response.output_tokens).to eq(candidates_tokens + thoughts_tokens)
235235
end
236+
237+
describe 'parameter mapping' do
  # A bare provider instance is enough: apply_parameter_mappings is pure
  # hash manipulation and performs no network calls.
  let(:provider) do
    config = RubyLLM::Configuration.new
    config.gemini_api_key = 'test_key'
    RubyLLM::Providers::Gemini.new(config)
  end

  it 'maps max_tokens to generationConfig.maxOutputTokens' do
    mapped = provider.send(:apply_parameter_mappings, { max_tokens: 1000 })

    expect(mapped).to eq({ generationConfig: { maxOutputTokens: 1000 } })
  end

  it 'removes max_tokens from the params after mapping' do
    mapped = provider.send(:apply_parameter_mappings, { max_tokens: 500 })

    expect(mapped).not_to have_key(:max_tokens)
  end

  it 'preserves other params while mapping max_tokens' do
    mapped = provider.send(:apply_parameter_mappings, { max_tokens: 1000, other_param: 'value' })

    expect(mapped[:other_param]).to eq('value')
    expect(mapped.dig(:generationConfig, :maxOutputTokens)).to eq(1000)
  end

  it 'merges with existing generationConfig hash' do
    mapped = provider.send(:apply_parameter_mappings,
                           { max_tokens: 500, generationConfig: { temperature: 0.7 } })

    expect(mapped.dig(:generationConfig, :temperature)).to eq(0.7)
    expect(mapped.dig(:generationConfig, :maxOutputTokens)).to eq(500)
  end

  it 'handles params without max_tokens' do
    mapped = provider.send(:apply_parameter_mappings, { other: 'value', custom: 123 })

    expect(mapped).to eq({ other: 'value', custom: 123 })
  end
end
236281
end

0 commit comments

Comments
 (0)