Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ arrow-array = "55.2"
as-any = "0.3.2"
assert_fs = "1.1.3"
async-stream = "0.3.6"
async-trait = "0.1.88"
aws-config = "1.8.0"
aws-sdk-bedrockruntime = "1.95.0"
aws-smithy-types = "1.3.2"
Expand Down
38 changes: 22 additions & 16 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,30 @@ doctest = false
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
reqwest = { workspace = true, features = ["json", "stream", "multipart"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tracing = { workspace = true }
as-any = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true, optional = true }
base64 = { workspace = true }
bytes = { workspace = true }
epub = { workspace = true, optional = true }
futures = { workspace = true }
ordered-float = { workspace = true }
schemars = { workspace = true }
thiserror = { workspace = true }
rig-derive = { version = "0.1.4", path = "rig-core-derive", optional = true }
glob = { workspace = true }
lopdf = { workspace = true, optional = true }
epub = { workspace = true, optional = true }
mcp-core = { workspace = true, optional = true }
mime_guess = { workspace = true }
ordered-float = { workspace = true }
quick-xml = { workspace = true, optional = true }
rayon = { workspace = true, optional = true }
reqwest = { workspace = true, features = ["json", "stream", "multipart"] }
rig-derive = { version = "0.1.4", path = "rig-core-derive", optional = true }
schemars = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
url = { workspace = true }
worker = { workspace = true, optional = true }
mcp-core = { workspace = true, optional = true }
bytes = { workspace = true }
async-stream = { workspace = true }
mime_guess = { workspace = true }
base64 = { workspace = true }
as-any = { workspace = true }
rmcp = { version = "0.5", optional = true, features = ["client"] }
url = { workspace = true }
reqwest-eventsource = { workspace = true }

[dev-dependencies]
Expand Down Expand Up @@ -76,6 +77,7 @@ rayon = ["dep:rayon"]
worker = ["dep:worker"]
mcp = ["dep:mcp-core"]
rmcp = ["dep:rmcp"]
hooks = ["dep:async-trait"]
socks = ["reqwest/socks"]
# Replace "default-tls" with "rustls-tls" in "reqwest/default"
reqwest-rustls = [
Expand Down Expand Up @@ -148,3 +150,7 @@ required-features = ["derive"]
[[example]]
name = "rmcp"
required-features = ["rmcp"]

[[example]]
name = "request_hook"
required-features = ["hooks"]
96 changes: 96 additions & 0 deletions rig-core/examples/request_hook.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::env;

use rig::agent::PromptHook;
use rig::client::CompletionClient;
use rig::completion::{CompletionModel, CompletionResponse, Message, Prompt};
use rig::message::{AssistantContent, UserContent};
use rig::providers;

struct SessionIdHook<'a> {
session_id: &'a str,
}

#[async_trait::async_trait]
impl<'a, M: CompletionModel> PromptHook<M> for SessionIdHook<'a> {
async fn on_tool_call(&self, tool_name: &str, args: &str) {
println!(
"[Session {}] Calling tool: {} with args: {}",
self.session_id, tool_name, args
);
}
async fn on_tool_result(&self, tool_name: &str, args: &str, result: &str) {
println!(
"[Session {}] Tool result for {} (args: {}): {}",
self.session_id, tool_name, args, result
);
}

async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) {
println!(
"[Session {}] Sending prompt: {}",
self.session_id,
match prompt {
Message::User { content } => content
.iter()
.filter_map(|c| {
if let UserContent::Text(text_content) = c {
Some(text_content.text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n"),
Message::Assistant { content, .. } => content
.iter()
.filter_map(|c| if let AssistantContent::Text(text_content) = c {
Some(text_content.text.clone())
} else {
None
})
.collect::<Vec<_>>()
.join("\n"),
}
);
}

async fn on_completion_response(
&self,
_prompt: &Message,
response: &CompletionResponse<M::Response>,
) {
if let Ok(resp) = serde_json::to_string(&response.raw_response) {
println!("[Session {}] Received response: {}", self.session_id, resp);
} else {
println!(
"[Session {}] Received response: <non-serializable>",
self.session_id
);
}
}
}

// Example main function (pseudo-code, as actual Agent/CompletionModel setup is project-specific)
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let client = providers::openai::Client::new(
&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"),
);

// Create agent with a single context prompt
let comedian_agent = client
.agent("gpt-4o")
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

let session_id = "abc123";
let hook = SessionIdHook { session_id };

// Prompt the agent and print the response
comedian_agent
.prompt("Entertain me!")
.with_hook(&hook)
.await?;

Ok(())
}
2 changes: 2 additions & 0 deletions rig-core/src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,6 @@ mod prompt_request;

pub use builder::AgentBuilder;
pub use completion::Agent;
#[cfg(feature = "hooks")]
pub use prompt_request::PromptHook;
pub use prompt_request::{PromptRequest, PromptResponse};
82 changes: 73 additions & 9 deletions rig-core/src/agent/prompt_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ pub struct PromptRequest<'a, S: PromptType, M: CompletionModel> {
agent: &'a Agent<M>,
/// Phantom data to track the type of the request
state: PhantomData<S>,
#[cfg(feature = "hooks")]
/// Optional per-request hook for events
hook: Option<&'a dyn PromptHook<M>>,
}

impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
Expand All @@ -49,6 +52,8 @@ impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
max_depth: 0,
agent,
state: PhantomData,
#[cfg(feature = "hooks")]
hook: None,
}
}

Expand All @@ -64,6 +69,8 @@ impl<'a, M: CompletionModel> PromptRequest<'a, Standard, M> {
max_depth: self.max_depth,
agent: self.agent,
state: PhantomData,
#[cfg(feature = "hooks")]
hook: self.hook,
}
}
}
Expand All @@ -78,6 +85,8 @@ impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
max_depth: depth,
agent: self.agent,
state: PhantomData,
#[cfg(feature = "hooks")]
hook: self.hook,
}
}

Expand All @@ -89,8 +98,44 @@ impl<'a, S: PromptType, M: CompletionModel> PromptRequest<'a, S, M> {
max_depth: self.max_depth,
agent: self.agent,
state: PhantomData,
#[cfg(feature = "hooks")]
hook: self.hook,
}
}

#[cfg(feature = "hooks")]
/// Attach a per-request hook for tool call events
pub fn with_hook(self, hook: &'a dyn PromptHook<M>) -> PromptRequest<'a, S, M> {
PromptRequest {
prompt: self.prompt,
chat_history: self.chat_history,
max_depth: self.max_depth,
agent: self.agent,
state: PhantomData,
#[cfg(feature = "hooks")]
hook: Some(hook),
}
}
}
#[cfg(feature = "hooks")]
/// Trait for per-request hooks to observe tool call events
#[async_trait::async_trait]
pub trait PromptHook<M: CompletionModel>: Send + Sync {
/// Called before the prompt is sent to the model
async fn on_completion_call(&self, prompt: &Message, history: &[Message]);

/// Called after the prompt is sent to the model and a response is received
async fn on_completion_response(
&self,
prompt: &Message,
response: &crate::completion::CompletionResponse<M::Response>,
);

/// Called before a tool is invoked
async fn on_tool_call(&self, tool_name: &str, args: &str);

/// Called after a tool is invoked
async fn on_tool_result(&self, tool_name: &str, args: &str, result: &str);
}

/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
Expand Down Expand Up @@ -149,7 +194,7 @@ impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
let mut current_max_depth = 0;
let mut usage = Usage::new();

// We need to do atleast 2 loops for 1 roundtrip (user expects normal message)
// We need to do at least 2 loops for 1 roundtrip (user expects normal message)
let last_prompt = loop {
let prompt = chat_history
.last()
Expand All @@ -170,14 +215,28 @@ impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
);
}

#[cfg(feature = "hooks")]
if let Some(hook) = self.hook.as_ref() {
hook.on_completion_call(&prompt, &chat_history[..chat_history.len() - 1])
.await;
}

let resp = agent
.completion(prompt, chat_history[..chat_history.len() - 1].to_vec())
.completion(
prompt.clone(),
chat_history[..chat_history.len() - 1].to_vec(),
)
.await?
.send()
.await?;

usage += resp.usage;

#[cfg(feature = "hooks")]
if let Some(hook) = self.hook.as_ref() {
hook.on_completion_response(&prompt, &resp).await;
}

let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
.choice
.iter()
Expand Down Expand Up @@ -212,13 +271,18 @@ impl<M: CompletionModel> PromptRequest<'_, Extended, M> {
let tool_content = stream::iter(tool_calls)
.then(|choice| async move {
if let AssistantContent::ToolCall(tool_call) = choice {
let output = agent
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?;
let tool_name = &tool_call.function.name;
let args = tool_call.function.arguments.to_string();
#[cfg(feature = "hooks")]
if let Some(hook) = self.hook.as_ref() {
hook.on_tool_call(tool_name, &args).await;
}
let output = agent.tools.call(tool_name, args.clone()).await?;
#[cfg(feature = "hooks")]
if let Some(hook) = self.hook.as_ref() {
hook.on_tool_result(tool_name, &args, &output.to_string())
.await;
}
if let Some(call_id) = tool_call.call_id.clone() {
Ok(UserContent::tool_result_with_call_id(
tool_call.id.clone(),
Expand Down