diff --git a/rig-core/examples/agent_with_inception.rs b/rig-core/examples/agent_with_inception.rs new file mode 100644 index 000000000..c7402950c --- /dev/null +++ b/rig-core/examples/agent_with_inception.rs @@ -0,0 +1,27 @@ +use std::env; + +use rig::{ + completion::Prompt, + providers::inception::{ClientBuilder, MERCURY_CODER_SMALL}, +}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create Inception Labs client + let client = + ClientBuilder::new(&env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set")) + .build(); + + // Create agent with a single context prompt + let agent = client + .agent(MERCURY_CODER_SMALL) + .preamble("You are a helpful AI assistant.") + .temperature(0.0) + .build(); + + // Prompt the agent and print the response + let response = agent.prompt("Hello, how are you?").await?; + println!("{}", response); + + Ok(()) +} diff --git a/rig-core/examples/inception_streaming.rs b/rig-core/examples/inception_streaming.rs new file mode 100644 index 000000000..1e63f74e4 --- /dev/null +++ b/rig-core/examples/inception_streaming.rs @@ -0,0 +1,22 @@ +use rig::{ + providers::inception::{self, completion::MERCURY_CODER_SMALL}, + streaming::{stream_to_stdout, StreamingPrompt}, +}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create streaming agent with a single context prompt + let agent = inception::Client::from_env() + .agent(MERCURY_CODER_SMALL) + .preamble("Be precise and concise.") + .build(); + + // Stream the response and print chunks as they arrive + let mut stream = agent + .stream_prompt("When and where and what type is the next solar eclipse?") + .await?; + + stream_to_stdout(agent, &mut stream).await?; + + Ok(()) +} diff --git a/rig-core/src/providers/inception/client.rs b/rig-core/src/providers/inception/client.rs new file mode 100644 index 000000000..7871b0896 --- /dev/null +++ b/rig-core/src/providers/inception/client.rs @@ -0,0 +1,91 @@ +use schemars::JsonSchema; +use 
serde::{Deserialize, Serialize};
+
+use crate::{
+    agent::AgentBuilder, extractor::ExtractorBuilder,
+    providers::inception::completion::CompletionModel,
+};
+
+const INCEPTION_API_BASE_URL: &str = "https://api.inceptionlabs.ai/v1";
+
+#[derive(Clone)]
+pub struct ClientBuilder<'a> {
+    api_key: &'a str,
+    base_url: &'a str,
+}
+
+impl<'a> ClientBuilder<'a> {
+    pub fn new(api_key: &'a str) -> Self {
+        Self {
+            api_key,
+            base_url: INCEPTION_API_BASE_URL,
+        }
+    }
+
+    pub fn base_url(mut self, base_url: &'a str) -> Self {
+        self.base_url = base_url;
+        self
+    }
+
+    pub fn build(self) -> Client {
+        Client::new(self.api_key, self.base_url)
+    }
+}
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    pub fn new(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Content-Type",
+                        "application/json"
+                            .parse()
+                            .expect("Content-Type should parse"),
+                    );
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Authorization should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("Inception reqwest client should build"),
+        }
+    }
+
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("INCEPTION_API_KEY").expect("INCEPTION_API_KEY not set");
+        ClientBuilder::new(&api_key).build()
+    }
+
+    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        // Join base URL and path without mangling the scheme: a blanket
+        // `.replace("//", "/")` would turn "https://host" into "https:/host"
+        // and every request would fail with an invalid URL.
+        let url = format!(
+            "{}/{}",
+            self.base_url.trim_end_matches('/'),
+            path.trim_start_matches('/')
+        );
+        self.http_client.post(url)
+    }
+
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        
ExtractorBuilder::new(self.completion_model(model))
+    }
+}
diff --git a/rig-core/src/providers/inception/completion.rs b/rig-core/src/providers/inception/completion.rs
new file mode 100644
index 000000000..18f19dd80
--- /dev/null
+++ b/rig-core/src/providers/inception/completion.rs
@@ -0,0 +1,197 @@
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+use crate::{
+    completion::{self, CompletionError},
+    message::{self, MessageError},
+    OneOrMany,
+};
+
+use super::client::Client;
+
+// ================================================================
+// Inception Completion API
+// ================================================================
+/// `mercury-coder-small` completion model
+pub const MERCURY_CODER_SMALL: &str = "mercury-coder-small";
+
+#[derive(Debug, Deserialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub choices: Vec<Choice>,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub usage: Usage,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+    pub index: usize,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+impl From<Choice> for completion::AssistantContent {
+    fn from(choice: Choice) -> Self {
+        completion::AssistantContent::from(&choice)
+    }
+}
+
+impl From<&Choice> for completion::AssistantContent {
+    fn from(choice: &Choice) -> Self {
+        completion::AssistantContent::Text(completion::message::Text {
+            text: choice.message.content.clone(),
+        })
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {}\nCompletion tokens: 
{}\nTotal tokens: {}",
+            self.prompt_tokens, self.completion_tokens, self.total_tokens
+        )
+    }
+}
+
+impl TryFrom<message::Message> for Message {
+    type Error = MessageError;
+
+    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
+        Ok(match message {
+            message::Message::User { content } => Message {
+                role: Role::User,
+                content: match content.first() {
+                    message::UserContent::Text(message::Text { text }) => text.clone(),
+                    _ => {
+                        return Err(MessageError::ConversionError(
+                            "User message content must be a text message".to_string(),
+                        ))
+                    }
+                },
+            },
+            message::Message::Assistant { content } => Message {
+                role: Role::Assistant,
+                content: match content.first() {
+                    message::AssistantContent::Text(message::Text { text }) => text.clone(),
+                    _ => {
+                        return Err(MessageError::ConversionError(
+                            "Assistant message content must be a text message".to_string(),
+                        ))
+                    }
+                },
+            },
+        })
+    }
+}
+
+impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+    type Error = CompletionError;
+
+    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
+        let content = response.choices.iter().map(Into::into).collect::<Vec<_>>();
+
+        let choice = OneOrMany::many(content).map_err(|_| {
+            CompletionError::ResponseError(
+                "Response contained no message or tool call (empty)".to_owned(),
+            )
+        })?;
+
+        Ok(completion::CompletionResponse {
+            choice,
+            raw_response: response,
+        })
+    }
+}
+
+const MAX_TOKENS: u64 = 8192;
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    pub(crate) client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    #[cfg_attr(feature = "worker", worker::send)]
+    async fn completion(
+        &self,
+        completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let max_tokens = completion_request.max_tokens.unwrap_or(MAX_TOKENS);
+
+        let prompt_message: Message = 
completion_request
+            .prompt_with_context()
+            .try_into()
+            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
+
+        let mut messages = completion_request
+            .chat_history
+            .into_iter()
+            .map(|message| {
+                message
+                    .try_into()
+                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
+            })
+            .collect::<Result<Vec<Message>, _>>()?;
+
+        messages.push(prompt_message);
+
+        // NOTE(review): `temperature` and `additional_params` from the request are
+        // silently dropped here, unlike the streaming path in streaming.rs — confirm
+        // and forward them the same way.
+        let request = json!({
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": max_tokens,
+        });
+
+        let response = self
+            .client
+            .post("chat/completions")
+            .json(&request)
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            let response = response.json::<CompletionResponse>().await?;
+            tracing::info!(target: "rig",
+                "Inception completion token usage: {}",
+                response.usage
+            );
+            Ok(response.try_into()?)
+        } else {
+            Err(CompletionError::ProviderError(response.text().await?))
+        }
+    }
+}
diff --git a/rig-core/src/providers/inception/mod.rs b/rig-core/src/providers/inception/mod.rs
new file mode 100644
index 000000000..08fec8529
--- /dev/null
+++ b/rig-core/src/providers/inception/mod.rs
@@ -0,0 +1,6 @@
+pub mod client;
+pub mod completion;
+pub mod streaming;
+
+pub use client::{Client, ClientBuilder};
+pub use completion::MERCURY_CODER_SMALL;
diff --git a/rig-core/src/providers/inception/streaming.rs b/rig-core/src/providers/inception/streaming.rs
new file mode 100644
index 000000000..ed3eb0a19
--- /dev/null
+++ b/rig-core/src/providers/inception/streaming.rs
@@ -0,0 +1,121 @@
+use async_stream::stream;
+use futures::StreamExt;
+use serde::Deserialize;
+use serde_json::json;
+
+use super::completion::{CompletionModel, Message};
+use crate::completion::{CompletionError, CompletionRequest};
+use crate::json_utils::merge_inplace;
+use crate::message::MessageError;
+use crate::providers::anthropic::decoders::sse::from_response as sse_from_response;
+use crate::streaming::{self, StreamingCompletionModel, StreamingResult};
+
+#[derive(Debug, Deserialize)]
+pub struct StreamingResponse {
+    pub 
id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<StreamingChoice>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct StreamingChoice {
+    pub index: usize,
+    pub delta: Delta,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Delta {
+    pub content: Option<String>,
+    pub role: Option<String>,
+}
+
+impl StreamingCompletionModel for CompletionModel {
+    async fn stream(
+        &self,
+        completion_request: CompletionRequest,
+    ) -> Result<StreamingResult, CompletionError> {
+        let prompt_message: Message = completion_request
+            .prompt_with_context()
+            .try_into()
+            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
+
+        let mut messages = completion_request
+            .chat_history
+            .into_iter()
+            .map(|message| {
+                message
+                    .try_into()
+                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
+            })
+            .collect::<Result<Vec<Message>, _>>()?;
+
+        messages.push(prompt_message);
+
+        let mut request = json!({
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": completion_request.max_tokens.unwrap_or(8192),
+            "stream": true,
+        });
+
+        if let Some(temperature) = completion_request.temperature {
+            merge_inplace(&mut request, json!({ "temperature": temperature }));
+        }
+
+        if let Some(ref params) = completion_request.additional_params {
+            merge_inplace(&mut request, params.clone())
+        }
+
+        let response = self
+            .client
+            .post("chat/completions")
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            return Err(CompletionError::ProviderError(response.text().await?));
+        }
+
+        // Use our SSE decoder to directly handle Server-Sent Events format
+        let sse_stream = sse_from_response(response);
+
+        Ok(Box::pin(stream! 
{
+            let mut sse_stream = Box::pin(sse_stream);
+
+            while let Some(sse_result) = sse_stream.next().await {
+                match sse_result {
+                    Ok(sse) => {
+                        // Parse the SSE data as a StreamingResponse
+                        match serde_json::from_str::<StreamingResponse>(&sse.data) {
+                            Ok(response) => {
+                                if let Some(choice) = response.choices.first() {
+                                    if let Some(content) = &choice.delta.content {
+                                        yield Ok(streaming::StreamingChoice::Message(content.clone()));
+                                    }
+                                    if choice.finish_reason.as_deref() == Some("stop") {
+                                        break;
+                                    }
+                                }
+                            },
+                            Err(e) => {
+                                if !sse.data.trim().is_empty() {
+                                    yield Err(CompletionError::ResponseError(
+                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
+                                    ));
+                                }
+                            }
+                        }
+                    },
+                    Err(e) => {
+                        yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
+                        break;
+                    }
+                }
+            }
+        }))
+    }
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 99f7a946a..fe2ee2e8e 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -11,6 +11,7 @@
 //! - DeepSeek
 //! - Azure OpenAI
 //! - Mira
+//! - Inception
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -54,6 +55,7 @@ pub mod gemini;
 pub mod groq;
 pub mod huggingface;
 pub mod hyperbolic;
+pub mod inception;
 pub mod mira;
 pub mod moonshot;
 pub mod ollama;