Skip to content

Commit 03c4fd6

Browse files
committed
Fix SSE parsing in Gemini provider
This fixes tool-calling issues on some Gemini models. The bug won't be observed if the server only returns small, complete chunks. We should not process an SSE *byte* stream line by line; therefore, we need correct buffering. Instead of reinventing the wheel, we introduce a new dependency (reqwest-eventsource) for parsing SSE chunks. We also set the id in tool_call to its function name instead of an empty string. Disclaimer: I did vibe code this (reviewed and tested internally, though).
1 parent d956ea5 commit 03c4fd6

File tree

5 files changed

+375
-57
lines changed

5 files changed

+375
-57
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ tracing-subscriber = "0.3.19"
9090
uuid = "1.17.0"
9191
worker = "0.6"
9292
zerocopy = "0.8.26"
93+
reqwest-eventsource = "0.6.0"
9394

9495
[workspace.metadata.cargo-autoinherit]
9596
# Skip cargo-autoinherit for these packages

rig-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ base64 = { workspace = true }
3838
as-any = { workspace = true }
3939
rmcp = { version = "0.3", optional = true, features = ["client"] }
4040
url = { workspace = true }
41+
reqwest-eventsource = { workspace = true }
4142

4243
[dev-dependencies]
4344
anyhow = { workspace = true }
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
use futures::{Stream, StreamExt};
2+
use rig::providers::gemini;
3+
use rig::{
4+
OneOrMany,
5+
agent::Agent,
6+
client::{CompletionClient, ProviderClient},
7+
completion::{self, CompletionError, CompletionModel, PromptError, ToolDefinition},
8+
message::{AssistantContent, Message, Text, ToolResultContent, UserContent},
9+
streaming::{StreamedAssistantContent, StreamingCompletion},
10+
tool::{Tool, ToolSetError},
11+
};
12+
use schemars::{JsonSchema, schema_for};
13+
use serde::{Deserialize, Serialize};
14+
15+
use std::pin::Pin;
16+
use thiserror::Error;
17+
18+
/// Error type carried by the items of `StreamingResult`.
///
/// Each variant wraps one of the error types imported from `rig` and is
/// convertible from it through `#[from]`, so `?` and `.into()` can be used on
/// those errors directly.
#[derive(Debug, Error)]
19+
enum StreamingError {
20+
/// A `CompletionError`, displayed with a `CompletionError: ` prefix.
#[error("CompletionError: {0}")]
21+
Completion(#[from] CompletionError),
22+
/// A `PromptError`, displayed with a `PromptError: ` prefix.
#[error("PromptError: {0}")]
23+
Prompt(#[from] PromptError),
24+
/// A `ToolSetError`, displayed with a `ToolSetError: ` prefix.
#[error("ToolSetError: {0}")]
25+
Tool(#[from] ToolSetError),
26+
}
27+
28+
type StreamingResult = Pin<Box<dyn Stream<Item = Result<Text, StreamingError>> + Send>>;
29+
30+
#[tokio::main]
31+
async fn main() -> anyhow::Result<()> {
32+
// tracing_subscriber::registry()
33+
// .with(
34+
// tracing_subscriber::EnvFilter::try_from_default_env()
35+
// .unwrap_or_else(|_| "info".into()),
36+
// )
37+
// .with(tracing_subscriber::fmt::layer())
38+
// .init();
39+
40+
// Create gemini client
41+
let llm_client = gemini::Client::from_env();
42+
43+
// Create agent with a single context prompt and a calculator tools
44+
let calculator_agent = llm_client
45+
.agent("gemini-2.5-flash")
46+
.preamble("You are an calculator. You must use tools to get the user result")
47+
.tool(Add)
48+
.tool(Subtract)
49+
.tool(Multiply)
50+
.tool(Divide)
51+
.build();
52+
53+
// Prompt the agent and get the stream
54+
let mut stream = multi_turn_prompt(
55+
calculator_agent,
56+
"Calculate 2 * (3 + 5) / 9 = ?. Describe the result to me.",
57+
Vec::new(),
58+
)
59+
.await;
60+
61+
custom_stream_to_stdout(&mut stream).await?;
62+
63+
Ok(())
64+
}
65+
66+
async fn multi_turn_prompt<M>(
67+
agent: Agent<M>,
68+
prompt: impl Into<Message> + Send,
69+
mut chat_history: Vec<completion::Message>,
70+
) -> StreamingResult
71+
where
72+
M: CompletionModel + 'static,
73+
<M as CompletionModel>::StreamingResponse: std::marker::Send,
74+
{
75+
let prompt: Message = prompt.into();
76+
77+
(Box::pin(async_stream::stream! {
78+
let mut current_prompt = prompt;
79+
let mut did_call_tool = false;
80+
81+
'outer: loop {
82+
let mut stream = agent
83+
.stream_completion(current_prompt.clone(), chat_history.clone())
84+
.await?
85+
.stream()
86+
.await?;
87+
88+
chat_history.push(current_prompt.clone());
89+
90+
let mut tool_calls = vec![];
91+
let mut tool_results = vec![];
92+
93+
while let Some(content) = stream.next().await {
94+
match content {
95+
Ok(StreamedAssistantContent::Text(text)) => {
96+
yield Ok(Text { text: text.text });
97+
did_call_tool = false;
98+
},
99+
Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
100+
let tool_result =
101+
agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?;
102+
103+
let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
104+
105+
tool_calls.push(tool_call_msg);
106+
tool_results.push((tool_call.id, tool_call.call_id, tool_result));
107+
108+
did_call_tool = true;
109+
// break;
110+
},
111+
Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning })) => {
112+
yield Ok(Text { text: reasoning });
113+
did_call_tool = false;
114+
},
115+
Ok(_) => {
116+
// do nothing here as we don't need to accumulate token usage
117+
}
118+
Err(e) => {
119+
yield Err(e.into());
120+
break 'outer;
121+
}
122+
}
123+
}
124+
125+
// Add (parallel) tool calls to chat history
126+
if !tool_calls.is_empty() {
127+
chat_history.push(Message::Assistant {
128+
id: None,
129+
content: OneOrMany::many(tool_calls).expect("Impossible EmptyListError"),
130+
});
131+
}
132+
133+
// Add tool results to chat history
134+
for (id, call_id, tool_result) in tool_results {
135+
if let Some(call_id) = call_id {
136+
chat_history.push(Message::User {
137+
content: OneOrMany::one(UserContent::tool_result_with_call_id(
138+
id,
139+
call_id,
140+
OneOrMany::one(ToolResultContent::text(tool_result)),
141+
)),
142+
});
143+
} else {
144+
chat_history.push(Message::User {
145+
content: OneOrMany::one(UserContent::tool_result(
146+
id,
147+
OneOrMany::one(ToolResultContent::text(tool_result)),
148+
)),
149+
});
150+
151+
}
152+
153+
}
154+
155+
// Set the current prompt to the last message in the chat history
156+
current_prompt = match chat_history.pop() {
157+
Some(prompt) => prompt,
158+
None => unreachable!("Chat history should never be empty at this point"),
159+
};
160+
161+
if !did_call_tool {
162+
break;
163+
}
164+
}
165+
166+
})) as _
167+
}
168+
169+
/// helper function to stream a completion request to stdout
170+
async fn custom_stream_to_stdout(stream: &mut StreamingResult) -> Result<(), std::io::Error> {
171+
print!("Response: ");
172+
while let Some(content) = stream.next().await {
173+
match content {
174+
Ok(Text { text }) => {
175+
print!("{text}");
176+
std::io::Write::flush(&mut std::io::stdout())?;
177+
}
178+
Err(err) => {
179+
eprintln!("Error: {err:#?}");
180+
}
181+
}
182+
}
183+
println!(); // New line after streaming completes
184+
185+
Ok(())
186+
}
187+
188+
/// Arguments shared by the arithmetic tools in this file: two `i32` operands.
///
/// `JsonSchema` is derived so that `schema_for!(OperationArgs)` can produce the
/// parameter schema used in each tool definition.
#[derive(Deserialize, JsonSchema)]
189+
struct OperationArgs {
190+
// First operand (left-hand side of the operation).
x: i32,
191+
// Second operand (right-hand side of the operation).
y: i32,
192+
}
193+
194+
/// Error type declared by the arithmetic tools (`type Error = MathError`).
///
/// It carries no payload and displays as `Math error`.
#[derive(Debug, thiserror::Error)]
195+
#[error("Math error")]
196+
struct MathError;
197+
198+
#[derive(Deserialize, Serialize)]
199+
struct Add;
200+
impl Tool for Add {
201+
const NAME: &'static str = "add";
202+
203+
type Error = MathError;
204+
type Args = OperationArgs;
205+
type Output = i32;
206+
207+
async fn definition(&self, _prompt: String) -> ToolDefinition {
208+
ToolDefinition {
209+
name: "add".to_string(),
210+
description: "Add x and y together".to_string(),
211+
parameters: serde_json::to_value(schema_for!(OperationArgs))
212+
.expect("converting JSON schema to JSON value should never fail"),
213+
}
214+
}
215+
216+
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
217+
let result = args.x + args.y;
218+
Ok(result)
219+
}
220+
}
221+
222+
#[derive(Deserialize, Serialize)]
223+
struct Subtract;
224+
impl Tool for Subtract {
225+
const NAME: &'static str = "subtract";
226+
227+
type Error = MathError;
228+
type Args = OperationArgs;
229+
type Output = i32;
230+
231+
async fn definition(&self, _prompt: String) -> ToolDefinition {
232+
ToolDefinition {
233+
name: "subtract".to_string(),
234+
description: "Subtract y from x (i.e.: x - y)".to_string(),
235+
parameters: serde_json::to_value(schema_for!(OperationArgs))
236+
.expect("converting JSON schema to JSON value should never fail"),
237+
}
238+
}
239+
240+
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
241+
let result = args.x - args.y;
242+
Ok(result)
243+
}
244+
}
245+
246+
struct Multiply;
247+
impl Tool for Multiply {
248+
const NAME: &'static str = "multiply";
249+
250+
type Error = MathError;
251+
type Args = OperationArgs;
252+
type Output = i32;
253+
254+
async fn definition(&self, _prompt: String) -> ToolDefinition {
255+
ToolDefinition {
256+
name: "multiply".to_string(),
257+
description: "Compute the product of x and y (i.e.: x * y)".to_string(),
258+
parameters: serde_json::to_value(schema_for!(OperationArgs))
259+
.expect("converting JSON schema to JSON value should never fail"),
260+
}
261+
}
262+
263+
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
264+
let result = args.x * args.y;
265+
Ok(result)
266+
}
267+
}
268+
269+
struct Divide;
270+
impl Tool for Divide {
271+
const NAME: &'static str = "divide";
272+
273+
type Error = MathError;
274+
type Args = OperationArgs;
275+
type Output = i32;
276+
277+
async fn definition(&self, _prompt: String) -> ToolDefinition {
278+
ToolDefinition {
279+
name: "divide".to_string(),
280+
description: "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios."
281+
.to_string(),
282+
parameters: serde_json::to_value(schema_for!(OperationArgs))
283+
.expect("converting JSON schema to JSON value should never fail"),
284+
}
285+
}
286+
287+
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
288+
let result = args.x / args.y;
289+
Ok(result)
290+
}
291+
}

0 commit comments

Comments
 (0)