From 39e640d22c845a3e47c1c61d9422df4b57b7889c Mon Sep 17 00:00:00 2001 From: rajiknows Date: Thu, 31 Jul 2025 10:14:54 +0530 Subject: [PATCH 1/8] feat: initialize foundry.rs --- rig-core/src/providers/foundry.rs | 337 ++++++++++++++++++++++++++++++ rig-core/src/providers/mod.rs | 1 + 2 files changed, 338 insertions(+) create mode 100644 rig-core/src/providers/foundry.rs diff --git a/rig-core/src/providers/foundry.rs b/rig-core/src/providers/foundry.rs new file mode 100644 index 000000000..e6cd08fbc --- /dev/null +++ b/rig-core/src/providers/foundry.rs @@ -0,0 +1,337 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +use crate::{ + Embed, OneOrMany, + client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient}, + completion::{self, CompletionError, CompletionRequest, Usage}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, + impl_conversion_traits, +}; + +const FOUNDRY_API_BASE_URL: &str = "http://localhost:8080"; + +pub struct ClientBuilder<'a> { + base_url: &'a str, + http_client: Option, +} + +impl<'a> ClientBuilder<'a> { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + base_url: FOUNDRY_API_BASE_URL, + http_client: None, + } + } + + pub fn base_url(mut self, base_url: &'a str) -> Self { + self.base_url = base_url; + self + } + + pub fn custom_client(mut self, client: reqwest::Client) -> Self { + self.http_client = Some(client); + self + } + + pub fn build(self) -> Result { + let http_client = if let Some(http_client) = self.http_client { + http_client + } else { + reqwest::Client::builder().build()? + }; + + Ok(Client { + base_url: self.base_url.to_string(), + http_client, + }) + } +} + +#[derive(Clone, Debug)] +pub struct Client { + base_url: String, + http_client: reqwest::Client, +} + +impl Default for Client { + fn default() -> Self { + Self::new() + } +} + +impl Client { + /// Create a new Ollama client builder. + /// + /// # Example + /// ``` + /// use rig::providers::ollama::{ClientBuilder, self}; + /// + /// // Initialize the Ollama client + /// let client = Client::builder() + /// .build() + /// ``` + pub fn builder() -> ClientBuilder<'static> { + ClientBuilder::new() + } + + /// Create a new Ollama client. For more control, use the `builder` method. + /// + /// # Panics + /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). + pub fn new() -> Self { + Self::builder().build().expect("Ollama client should build") + } + + pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path); + self.http_client.post(url) + } +} + +impl ProviderClient for Client { + fn from_env() -> Self + where + Self: Sized, + { + let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set"); + Self::builder().base_url(&api_base).build().unwrap() + } + + fn from_val(input: crate::client::ProviderValue) -> Self { + let crate::client::ProviderValue::Simple(_) = input else { + panic!("Incorrect provider value type") + }; + + Self::new() + } +} + +impl CompletionClient for Client { + type CompletionModel = CompletionModel; + + fn completion_model(&self, model: &str) -> CompletionModel { + CompletionModel::new(self.clone(), model) + } +} + +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; + fn embedding_model(&self, model: &str) -> EmbeddingModel { + EmbeddingModel::new(self.clone(), model, 0) + } + fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + EmbeddingModel::new(self.clone(), model, ndims) + } + fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + EmbeddingsBuilder::new(self.embedding_model(model)) + } +} + +impl_conversion_traits!( + AsTranscription, + AsImageGeneration, + AsAudioGeneration for Client +); + +// ---------- API Error and Response Structures ---------- + +#[derive(Debug, Deserialize)] +struct ApiErrorResponse { + message: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum ApiResponse { + Ok(T), + Err(ApiErrorResponse), +} + +// ---------- Embedding API ---------- + +/// TODO: mention the commpletion models here + +#[derive(Debug, Serialize, Deserialize)] +struct EmbeddingData { + object: String, + embedding: Vec, + index: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +struct EmbeddingResponse { + object: String, + data: Vec, + model: String, + usage: Usage, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Usage { + prompt_tokens: u64, + total_tokens: u64, +} + +impl From for EmbeddingError { + fn from(err: ApiErrorResponse) -> Self { + EmbeddingError::ProviderError(err.message) + } +} + +impl From> for Result { + fn from(value: ApiResponse) -> Self { + match value { + ApiResponse::Ok(response) => Ok(response), + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } + } +} + +// ----------- Embedding Model -------------- + +#[derive(Clone)] +pub struct EmbeddingModel { + client: Client, + pub model: String, + ndims: usize, +} + +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { + Self { + client, + model: model.to_owned(), + ndims, + } + } +} + +impl embeddings::EmbeddingModel for EmbeddingModel { + const MAX_DOCUMENTS: usize = 1024; + fn ndims(&self) -> usize { + self.ndims + } + + #[cfg_attr(feature = "worker", worker::send)] + async fn embed_texts( + &self, + documents: impl IntoIterator + Send, + ) -> Result, EmbeddingError> { + let docs: Vec = documents.into_iter().collect(); + let payload = json!({ + "model": self.model, + "input":docs, + }); + let response = self + .client + .post("v1/embeddings") + .json(&payload) + .send() + .await + .map_err(|e| EmbeddingError::ResponseError(e.to_string()))?; + if response.status().is_success() { + let api_resp: EmbeddingResponse = response + .json() + .await + .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?; + + if api_resp.data.len() != docs.len() { + return Err(EmbeddingError::ResponseError( + "Number of returned embeddings does not match input".into(), + )); + } + Ok(api_resp + .data + .into_iter() + .zip(docs.into_iter()) + .map(|(vec, document)| embeddings::Embedding { document, vec }) + .collect()) + } else { + Err(EmbeddingError::ProviderError(response.text().await?)) + } + } +} + +// ----------- Completions API ------------- + +// TODO: add models here + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct CompletionsUsage { + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub total_tokens: u64, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct Choice { + pub index: u64, + pub message: CompletionMessage, + pub finish_reason: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct CompletionMessage { + pub role: String, + pub content: String, +} + +pub struct CompletionResponse { + pub id: String, + pub object: String, + pub created: String, + pub model: String, + pub choices: Vec, + pub usage: CompletionsUsage, +} + +impl TryFrom for completion::CompletionResponse { + type Error = CompletionError; + fn try_from(resp: CompletionResponse) -> Result { + let mut assitant_contents = Vec::new(); + + // foundry only responds with an array of choices which have + // role and content (role is always "assistant" for responses) + for choice in resp.choices.clone() { + assitant_contents.push(completion::AssistantContent::text(&choice.message.content)); + } + + let choice = OneOrMany::many(assitant_contents) + .map_err(|_| CompletionError::ResponseError("No content provided".to_owned()))?; + + Ok(completion::CompletionResponse { + choice, + usage: rig::completion::request::Usage { + input_tokens: resp.usage.prompt_tokens, + output_tokens: resp.usage.completion_tokens, + total_tokens: resp.usage.total_tokens, + }, + raw_response: resp, + }) + } +} + +// ----------- Completion Model ---------- + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + pub model: String, +} + +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_owned(), + } + } + + fn create_completion_request( + &self, + completion_request: CompletionRequest, + ) -> Result { + } +} diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 53ca700f5..1d5976346 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -49,6 +49,7 @@ pub mod anthropic; pub mod azure; pub mod cohere; pub mod deepseek; +pub mod foundry; pub mod galadriel; pub mod gemini; pub mod groq; From abee08be1493008c5ac722c89ae8fd4c51d64ebf Mon Sep 17 00:00:00 2001 From: rajiknows Date: Thu, 31 Jul 2025 12:58:53 +0530 Subject: [PATCH 2/8] feat: foundry embeddings api --- rig-core/src/providers/foundry.rs | 49 +++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/rig-core/src/providers/foundry.rs b/rig-core/src/providers/foundry.rs index e6cd08fbc..598b85efb 100644 --- a/rig-core/src/providers/foundry.rs +++ b/rig-core/src/providers/foundry.rs @@ -2,11 +2,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use crate::{ - Embed, OneOrMany, - client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient}, - completion::{self, CompletionError, CompletionRequest, Usage}, - embeddings::{self, EmbeddingError, EmbeddingsBuilder}, - impl_conversion_traits, + client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient}, completion::{self, CompletionError, CompletionRequest, ToolDefinition, Usage}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, impl_conversion_traits, json_utils, Embed, OneOrMany }; const FOUNDRY_API_BASE_URL: &str = "http://localhost:8080"; @@ -333,5 +329,48 @@ impl CompletionModel { &self, completion_request: CompletionRequest, ) -> Result { + let mut partial_history = vec![]; + if let Some(docs) = completion_request.normalized_documents(){ + partial_history.push(docs); + } + partial_history.extend(completion_request.chat_history); + + let mut full_history = completion_request.preamble.map_or_else(Vec::new, |preamble| vec![CompletionMessage::from(&preamble)]); + + // convert and extend the rest of the history + full_history.extend( + partial_history + .into_iter() + .map(|msg| msg.try_into()) + .collect::>,_>>()? + .into_iter() + .flatten() + .collect::>(); + ); + + let mut requeest_payload = json!({ + "model": self.model, + "messages": full_history, + "temparature": completion_request.temperature, + "stream": false, + }); + + if !completion_request.tools.is_empty(){ + // Foundry's functions have same structure as completion::ToolDefination + requeest_payload["functions"] = json!( + completion_request + .tools + .into_iter() + .map(|tool| tool.into()) + .collect::>() + ); + } + + tracing::debug!(target: "rig", "Chat mode payload: {}", requeest_payload); + + Ok(requeest_payload) } } + +// ---------- CompletionModel Implementation ---------- +// From 13a8ae788171303397ee3903e6441b23c5f6709f Mon Sep 17 00:00:00 2001 From: rajiknows Date: Fri, 1 Aug 2025 13:36:00 +0530 Subject: [PATCH 3/8] feat: foundry integration --- rig-core/src/providers/foundry.rs | 467 ++++++++++++++++++++++++++++-- 1 file changed, 443 insertions(+), 24 deletions(-) diff --git a/rig-core/src/providers/foundry.rs b/rig-core/src/providers/foundry.rs index 598b85efb..de616079a 100644 --- a/rig-core/src/providers/foundry.rs +++ b/rig-core/src/providers/foundry.rs @@ -1,11 +1,20 @@ +use async_stream::stream; +use futures::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use std::{convert::TryFrom, str::FromStr}; use crate::{ - client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient}, completion::{self, CompletionError, CompletionRequest, ToolDefinition, Usage}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, impl_conversion_traits, json_utils, Embed, OneOrMany + Embed, OneOrMany, + client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient}, + completion::{self, CompletionError, CompletionRequest, ToolDefinition}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, + impl_conversion_traits, json_utils, + message::Text, + streaming::{self, RawStreamingChoice}, }; -const FOUNDRY_API_BASE_URL: &str = "http://localhost:8080"; +const FOUNDRY_API_BASE_URL: &str = "http://localhost:42069"; pub struct ClientBuilder<'a> { base_url: &'a str, @@ -58,13 +67,13 @@ impl Default for Client { } impl Client { - /// Create a new Ollama client builder. + /// Create a new Foundry client builder. /// /// # Example /// ``` - /// use rig::providers::ollama::{ClientBuilder, self}; + /// use rig::providers::foundry::{ClientBuilder, self}; /// - /// // Initialize the Ollama client + /// // Initialize the Foundry client /// let client = Client::builder() /// .build() /// ``` @@ -72,7 +81,7 @@ impl Client { ClientBuilder::new() } - /// Create a new Ollama client. For more control, use the `builder` method. + /// Create a new Foundry client. For more control, use the `builder` method. /// /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). @@ -145,10 +154,12 @@ enum ApiResponse { Err(ApiErrorResponse), } -// ---------- Embedding API ---------- - -/// TODO: mention the commpletion models here +pub const COHERE_EMBED_V4_0: &str = "embed-v-4-0"; +pub const COHERE_EMBED_V3_ENGLISH: &str = "Cohere-embed-v3-english"; +pub const COHERE_EMBED_V3_MULTILINGUAL: &str = "Cohere-embed-v3-multilingual"; +pub const OPENAI_TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; +// ---------- Embedding API ---------- #[derive(Debug, Serialize, Deserialize)] struct EmbeddingData { object: String, @@ -186,7 +197,6 @@ impl From> for Result Result { let mut partial_history = vec![]; - if let Some(docs) = completion_request.normalized_documents(){ + if let Some(docs) = completion_request.normalized_documents() { partial_history.push(docs); } partial_history.extend(completion_request.chat_history); - let mut full_history = completion_request.preamble.map_or_else(Vec::new, |preamble| vec![CompletionMessage::from(&preamble)]); + let mut full_history = completion_request + .preamble + .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]); // convert and extend the rest of the history full_history.extend( partial_history .into_iter() .map(|msg| msg.try_into()) - .collect::>,_>>()? + .collect::>, _>>()? .into_iter() .flatten() - .collect::>(); + .collect::>(), ); let mut requeest_payload = json!({ @@ -355,7 +390,7 @@ impl CompletionModel { "stream": false, }); - if !completion_request.tools.is_empty(){ + if !completion_request.tools.is_empty() { // Foundry's functions have same structure as completion::ToolDefination requeest_payload["functions"] = json!( completion_request @@ -372,5 +407,389 @@ impl CompletionModel { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct StreamingCompletionResponse { + pub id: String, + pub object: String, + pub created: String, + pub model: String, + pub usage: CompletionsUsage, +} + // ---------- CompletionModel Implementation ---------- -// +impl completion::CompletionModel for CompletionModel { + type Response = CompletionResponse; + type StreamingResponse = StreamingCompletionResponse; + + #[cfg_attr(feature = "worker", worker::send)] + async fn completion( + &self, + completion_request: CompletionRequest, + ) -> Result, CompletionError> { + let request_payload = self.create_completion_request(completion_request)?; + + let response = self + .client + .post("/v1/chat/completions") + .json(&request_payload) + .send() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + + if response.status().is_success() { + let text = response + .text() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + tracing::debug!(target: "rig", "Foundry chat response: {}", text); + let chat_resp: CompletionResponse = serde_json::from_str(&text) + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let conv: completion::CompletionResponse = chat_resp.try_into()?; + Ok(conv) + } else { + let err_text = response + .text() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + Err(CompletionError::ProviderError(err_text)) + } + } + + #[cfg_attr(feature = "worker", worker::send)] + async fn stream( + &self, + request: CompletionRequest, + ) -> Result< + crate::streaming::StreamingCompletionResponse, + CompletionError, + > { + let mut request_payload = self.create_completion_request(request)?; + json_utils::merge_inplace(&mut request_payload, json!({"stream": true})); + + let response = self + .client + .post("/v1/chat/completions") + .json(&request_payload) + .send() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + + if !response.status().is_success() { + let err_text = response + .text() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + return Err(CompletionError::ProviderError(err_text)); + } + + let stream = Box::pin(stream! { + let mut stream = response.bytes_stream(); + while let Some(chunk_result) = stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(CompletionError::from(e)); + break; + } + }; + + let text = match String::from_utf8(chunk.to_vec()) { + Ok(t) => t, + Err(e) => { + yield Err(CompletionError::ResponseError(e.to_string())); + break; + } + }; + + for line in text.lines() { + let line = line.trim(); + + if line.is_empty() { + continue; + } + + let data_line = if line.starts_with("data: ") { + &line[6..] + } else { + line + }; + + // stream termination like openai + if data_line == "[DONE]" { + break; + } + + let Ok(response) = serde_json::from_str::(data_line) else { + continue; + }; + + for choice in response.choices.iter() { + if !choice.message.content.is_empty() { + yield Ok(RawStreamingChoice::Message(choice.message.content.clone())); + } + } + if response.choices.iter().any(|choice| choice.finish_reason == "stop") { + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + id: response.id.clone(), + object: response.object.clone(), + created: response.created.clone(), + model: response.model.clone(), + usage: response.usage.clone(), + })); + } + } + } + }); + + Ok(streaming::StreamingCompletionResponse::stream(stream)) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum Role { + #[serde(rename = "user")] + User, + #[serde(rename = "system")] + System, + #[serde(rename = "assistant")] + Assistant, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + role: Role, + content: String, +} + +impl TryFrom for Vec { + type Error = crate::message::MessageError; + fn try_from(internal_msg: crate::message::Message) -> Result { + use crate::message::Message as InternalMessage; + match internal_msg { + InternalMessage::User { content, .. } => { + // Foundry doesn't support tool results in messages, so we skip them + let non_tool_content: Vec<_> = content + .into_iter() + .filter(|content| { + !matches!(content, crate::message::UserContent::ToolResult(_)) + }) + .collect(); + + let text_contents: Vec = non_tool_content + .into_iter() + .filter_map(|content| match content { + crate::message::UserContent::Text(crate::message::Text { text }) => { + Some(text) + } + _ => None, + }) + .collect(); + + Ok(vec![Message { + role: Role::User, + content: text_contents.join(" "), + }]) + } + InternalMessage::Assistant { content, .. } => { + let text_contents: Vec = content + .into_iter() + .filter_map(|content| match content { + crate::message::AssistantContent::Text(text) => Some(text.text), + _ => None, + }) + .collect(); + + Ok(vec![Message { + role: Role::Assistant, + content: text_contents.join(" "), + }]) + } + } + } +} + +impl From for crate::completion::Message { + fn from(msg: Message) -> Self { + match msg.role { + Role::User => crate::completion::Message::User { + content: OneOrMany::one(crate::completion::message::UserContent::Text(Text { + text: msg.content, + })), + }, + Role::Assistant => crate::completion::Message::Assistant { + id: None, + content: OneOrMany::one(crate::completion::message::AssistantContent::Text({ + Text { text: msg.content } + })), + }, + Role::System => crate::completion::Message::User { + content: OneOrMany::one(crate::completion::message::UserContent::Text(Text { + text: msg.content, + })), + }, + } + } +} + +impl Message { + pub fn system(content: &str) -> Self { + Self { + role: Role::System, + content: content.to_owned(), + } + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct SystemContent { + #[serde(default)] + r#type: SystemContentType, + text: String, +} + +#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum SystemContentType { + #[default] + Text, +} + +impl From for SystemContent { + fn from(s: String) -> Self { + SystemContent { + r#type: SystemContentType::default(), + text: s, + } + } +} + +impl FromStr for SystemContent { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> Result { + Ok(SystemContent { + r#type: SystemContentType::default(), + text: s.to_string(), + }) + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct AssistantContent { + pub text: String, +} + +impl FromStr for AssistantContent { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> Result { + Ok(AssistantContent { text: s.to_owned() }) + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum UserContent { + Text { text: String }, +} + +impl FromStr for UserContent { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> Result { + Ok(UserContent::Text { text: s.to_owned() }) + } +} + +// ================================================================= +// Tests +// ================================================================= + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[tokio::test] + async fn test_chat_completion() { + let sample_chat_response = json!({ + "id": "chatcmpl-1234567890", + "object": "chat.completion", + "created": "1677851234", + "model": "Phi-4-mini-instruct-generic-cpu", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The sky is blue because of Rayleigh scattering." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } + }); + let sample_text = sample_chat_response.to_string(); + + let chat_resp: CompletionResponse = + serde_json::from_str(&sample_text).expect("Invalid JSON structure"); + let conv: completion::CompletionResponse = + chat_resp.try_into().unwrap(); + assert!( + !conv.choice.is_empty(), + "Expected non-empty choice in chat response" + ); + } + + #[test] + fn test_message_conversion() { + let provider_msg = Message { + role: Role::User, + content: "Test message".to_owned(), + }; + let comp_msg: crate::completion::Message = provider_msg.into(); + match comp_msg { + crate::completion::Message::User { content } => { + let first_content = content.first(); + match first_content { + crate::completion::message::UserContent::Text(text_struct) => { + assert_eq!(text_struct.text, "Test message"); + } + _ => panic!("Expected text content in conversion"), + } + } + _ => panic!("Conversion from provider Message to completion Message failed"), + } + } + + #[test] + fn test_system_content_from_string() { + let content = SystemContent::from("Test system message".to_string()); + assert_eq!(content.text, "Test system message"); + assert!(matches!(content.r#type, SystemContentType::Text)); + } + + #[test] + fn test_system_content_from_str() { + let content: SystemContent = "Test system message".parse().unwrap(); + assert_eq!(content.text, "Test system message"); + assert!(matches!(content.r#type, SystemContentType::Text)); + } + + #[test] + fn test_assistant_content_from_str() { + let content: AssistantContent = "Test assistant message".parse().unwrap(); + assert_eq!(content.text, "Test assistant message"); + } + + #[test] + fn test_user_content_from_str() { + let content: UserContent = "Test user message".parse().unwrap(); + match content { + UserContent::Text { text } => { + assert_eq!(text, "Test user message"); + } + } + } +} From 85333472c14d0bd9b75d52c9a9f943cc4eb2e3c6 Mon Sep 17 00:00:00 2001 From: rajiknows Date: Fri, 1 Aug 2025 14:12:17 +0530 Subject: [PATCH 4/8] fix: clippy --- rig-core/src/providers/foundry.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rig-core/src/providers/foundry.rs b/rig-core/src/providers/foundry.rs index de616079a..4034745dc 100644 --- a/rig-core/src/providers/foundry.rs +++ b/rig-core/src/providers/foundry.rs @@ -396,7 +396,7 @@ impl CompletionModel { completion_request .tools .into_iter() - .map(|tool| tool.into()) + .map(|tool| tool) .collect::>() ); } @@ -508,9 +508,9 @@ impl completion::CompletionModel for CompletionModel { continue; } - let data_line = if line.starts_with("data: ") { - &line[6..] - } else { + let data_line = if let Some(data) = line.strip_prefix("data: "){ + data + }else{ line }; From 6cd005ec07da31e5ece77f34eb723e7bd4e2d3ce Mon Sep 17 00:00:00 2001 From: rajiknows Date: Fri, 1 Aug 2025 14:15:43 +0530 Subject: [PATCH 5/8] fix: dumb clippy error again --- rig-core/src/providers/foundry.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rig-core/src/providers/foundry.rs b/rig-core/src/providers/foundry.rs index 4034745dc..510028224 100644 --- a/rig-core/src/providers/foundry.rs +++ b/rig-core/src/providers/foundry.rs @@ -396,7 +396,6 @@ impl CompletionModel { completion_request .tools .into_iter() - .map(|tool| tool) .collect::>() ); } From 6bafb90e36fa5dd37641f259824f5d2305a5d2b8 Mon Sep 17 00:00:00 2001 From: raj Date: Sun, 17 Aug 2025 12:50:40 +0530 Subject: [PATCH 6/8] fix: nit and tool calls --- .../{foundry.rs => foundry_local.rs} | 548 +++++++++++++----- rig-core/src/providers/mod.rs | 2 +- 2 files changed, 394 insertions(+), 156 deletions(-) rename rig-core/src/providers/{foundry.rs => foundry_local.rs} (54%) diff --git a/rig-core/src/providers/foundry.rs b/rig-core/src/providers/foundry_local.rs similarity index 54% rename from rig-core/src/providers/foundry.rs rename to rig-core/src/providers/foundry_local.rs index 510028224..884c4e375 100644 --- a/rig-core/src/providers/foundry.rs +++ b/rig-core/src/providers/foundry_local.rs @@ -2,15 +2,15 @@ use async_stream::stream; use futures::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; -use std::{convert::TryFrom, str::FromStr}; +use std::{collections::HashMap, convert::TryFrom, str::FromStr}; use crate::{ Embed, OneOrMany, client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient}, - completion::{self, CompletionError, CompletionRequest, ToolDefinition}, + completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, impl_conversion_traits, json_utils, - message::Text, + message::{self, Text}, streaming::{self, RawStreamingChoice}, }; @@ -67,7 +67,7 @@ impl Default for Client { } impl Client { - /// Create a new Foundry client builder. + /// Create a new Foundry-Local client builder. /// /// # Example /// ``` @@ -75,18 +75,20 @@ impl Client { /// /// // Initialize the Foundry client /// let client = Client::builder() - /// .build() + /// .build() /// ``` pub fn builder() -> ClientBuilder<'static> { ClientBuilder::new() } - /// Create a new Foundry client. For more control, use the `builder` method. + /// Create a new Foundry-Local client. For more control, use the `builder` method. /// /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). pub fn new() -> Self { - Self::builder().build().expect("Ollama client should build") + Self::builder() + .build() + .expect("Foundry-local client should build") } pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { @@ -100,7 +102,8 @@ impl ProviderClient for Client { where Self: Sized, { - let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set"); + let api_base = std::env::var("FOUNDRY_LOCAL_API_BASE_URL") + .expect("FOUNDRY_LOCAL_API_BASE_URL not set"); Self::builder().base_url(&api_base).build().unwrap() } @@ -263,7 +266,6 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } -// these i took from gemini ( review needed josh) pub const COHERE_COMMAND_A: &str = "Cohere-command-a"; pub const COHERE_COMMAND_R_PLUS: &str = "Cohere-command-r-plus-08-2024"; pub const COHERE_COMMAND_R: &str = "Cohere-command-r-08-2024"; @@ -286,6 +288,37 @@ pub const MICROSOFT_PHI_3_SMALL_8K_INSTRUCT: &str = "Phi-3-small-8k-instruct"; pub const MICROSOFT_PHI_3_SMALL_128K_INSTRUCT: &str = "Phi-3-small-128k-instruct"; // ----------- Completions API ------------- + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ToolDefinition { + #[serde(rename = "type")] + pub type_field: String, + pub function: crate::completion::ToolDefinition, +} + +impl From for ToolDefinition { + fn from(tool: crate::completion::ToolDefinition) -> Self { + ToolDefinition { + type_field: "function".to_owned(), + function: tool, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub r#type: String, + pub function: FunctionCall, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct CompletionsUsage { pub prompt_tokens: u64, @@ -297,20 +330,24 @@ pub struct CompletionsUsage { pub struct Choice { pub index: u64, pub message: CompletionMessage, - pub finish_reason: String, + pub finish_reason: Option, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct CompletionMessage { pub role: String, - pub content: String, + // Content can be null when tool_calls are present + #[serde(default, skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, } #[derive(Debug, Serialize, Deserialize)] pub struct CompletionResponse { pub id: String, pub object: String, - pub created: String, + pub created: u64, pub model: String, pub choices: Vec, pub usage: CompletionsUsage, @@ -319,20 +356,38 @@ pub struct CompletionResponse { impl TryFrom for completion::CompletionResponse { type Error = CompletionError; fn try_from(resp: CompletionResponse) -> Result { - let mut assitant_contents = Vec::new(); - - // foundry only responds with an array of choices which have - // role and content (role is always "assistant" for responses) - for choice in resp.choices.clone() { - assitant_contents.push(completion::AssistantContent::text(&choice.message.content)); - } + let choice = resp + .choices + .first() + .ok_or_else(|| CompletionError::ResponseError("No choices in response".to_owned()))?; + + let assistant_contents = if let Some(tool_calls) = &choice.message.tool_calls { + tool_calls + .iter() + .map(|tc| { + let arguments: Value = serde_json::from_str(&tc.function.arguments) + .map_err(|e| CompletionError::ResponseError(e.to_string()))?; + Ok(completion::AssistantContent::tool_call( + tc.id.clone(), + tc.function.name.clone(), + arguments, + )) + }) + .collect::, CompletionError>>()? + } else if let Some(content) = &choice.message.content { + vec![completion::AssistantContent::text(content)] + } else { + return Err(CompletionError::ResponseError( + "Response has neither content nor tool calls".to_owned(), + )); + }; - let choice = OneOrMany::many(assitant_contents) + let choice = OneOrMany::many(assistant_contents) .map_err(|_| CompletionError::ResponseError("No content provided".to_owned()))?; Ok(completion::CompletionResponse { choice, - usage: rig::completion::request::Usage { + usage: rig::completion::Usage { input_tokens: resp.usage.prompt_tokens, output_tokens: resp.usage.completion_tokens, total_tokens: resp.usage.total_tokens, @@ -383,42 +438,86 @@ impl CompletionModel { .collect::>(), ); - let mut requeest_payload = json!({ + let mut request_payload = json!({ "model": self.model, "messages": full_history, - "temparature": completion_request.temperature, + "temperature": completion_request.temperature, "stream": false, }); if !completion_request.tools.is_empty() { - // Foundry's functions have same structure as completion::ToolDefination - requeest_payload["functions"] = json!( + request_payload["tools"] = json!( completion_request .tools .into_iter() + .map(|tool| tool.into()) .collect::>() ); } - tracing::debug!(target: "rig", "Chat mode payload: {}", requeest_payload); + tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload); - Ok(requeest_payload) + Ok(request_payload) } } -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct StreamingCompletionResponse { +// Changed StreamingCompletionResponse to handle SSE deltas +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamingCompletionResponseChunk { pub id: String, pub object: String, - pub created: String, + pub created: u64, + pub model: String, + pub choices: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamingChoice { + pub index: u64, + pub delta: DeltaMessage, + pub finish_reason: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct DeltaMessage { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct StreamingToolCall { + pub index: u64, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub r#type: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct StreamingFunctionCall { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +// Final response for streaming mode +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct StreamingFinalResponse { + pub id: String, pub model: String, - pub usage: CompletionsUsage, } // ---------- CompletionModel Implementation ---------- impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; - type StreamingResponse = StreamingCompletionResponse; + type StreamingResponse = StreamingFinalResponse; #[cfg_attr(feature = "worker", worker::send)] async fn completion( @@ -440,9 +539,9 @@ impl completion::CompletionModel for CompletionModel { .text() .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; - tracing::debug!(target: "rig", "Foundry chat response: {}", text); + tracing::debug!(target: "rig", "Foundry-Local chat response: {}", text); let chat_resp: CompletionResponse = serde_json::from_str(&text) - .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + .map_err(|e| CompletionError::ResponseError(e.to_string()))?; let conv: completion::CompletionResponse = chat_resp.try_into()?; Ok(conv) } else { @@ -483,6 +582,11 @@ impl completion::CompletionModel for CompletionModel { let stream = Box::pin(stream! { let mut stream = response.bytes_stream(); + let mut tool_calls: HashMap, StreamingFunctionCall)> = HashMap::new(); + let mut final_response_id = "".to_string(); + let mut final_response_model = "".to_string(); + + while let Some(chunk_result) = stream.next().await { let chunk = match chunk_result { Ok(c) => c, @@ -501,63 +605,92 @@ impl completion::CompletionModel for CompletionModel { }; for line in text.lines() { - let line = line.trim(); - - if line.is_empty() { - continue; - } - - let data_line = if let Some(data) = line.strip_prefix("data: "){ - data - }else{ - line - }; + if line.starts_with("data: ") { + let data = &line[6..]; + if data == "[DONE]" { + break; + } - // stream termination like openai - if data_line == "[DONE]" { - break; + let Ok(chunk) = serde_json::from_str::(data) else { + continue; + }; + + final_response_id = chunk.id; + final_response_model = chunk.model; + + + for choice in chunk.choices { + if let Some(content) = choice.delta.content { + yield Ok(RawStreamingChoice::Message(content)); + } + + if let Some(delta_tool_calls) = choice.delta.tool_calls { + for stc in delta_tool_calls { + let entry = tool_calls.entry(stc.index).or_default(); + if let Some(id) = stc.id { + entry.0 = Some(id); + } + if let Some(function) = stc.function { + if let Some(name) = function.name { + entry.1.name.get_or_insert_with(String::new).push_str(&name); + } + if let Some(args) = function.arguments { + entry.1.arguments.get_or_insert_with(String::new).push_str(&args); + } + } + } + } + } } + } + } - let Ok(response) = serde_json::from_str::(data_line) else { + // yield any completed tool calls + for (_, (id, function)) in tool_calls { + if let (Some(id), Some(name), Some(arguments)) = (id, function.name, function.arguments) { + let Ok(args_json) = serde_json::from_str(&arguments) else { + yield Err(CompletionError::ResponseError(format!("Failed to parse tool call arguments: {}", arguments))); continue; }; - - for choice in response.choices.iter() { - if !choice.message.content.is_empty() { - yield Ok(RawStreamingChoice::Message(choice.message.content.clone())); - } - } - if response.choices.iter().any(|choice| choice.finish_reason == "stop") { - yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - id: response.id.clone(), - object: response.object.clone(), - created: response.created.clone(), - model: response.model.clone(), - usage: response.usage.clone(), - })); - } + yield Ok(RawStreamingChoice::ToolCall { + id: id.clone(), + name: name.clone(), + arguments: args_json, + call_id: None, + }); } } + + yield Ok(RawStreamingChoice::FinalResponse(StreamingFinalResponse { + id: final_response_id, + model: final_response_model, + })); }); Ok(streaming::StreamingCompletionResponse::stream(stream)) } } -#[derive(Debug, Serialize, Deserialize)] -pub enum Role { - #[serde(rename = "user")] - User, - #[serde(rename = "system")] - System, - #[serde(rename = "assistant")] - Assistant, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Message { - role: Role, - content: String, +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum Message { + System { + content: String, + }, + User { + content: String, + }, + Assistant { + // content can be null when tool_calls are present + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + }, + Tool { + tool_call_id: String, + content: String, + }, } impl TryFrom for Vec { @@ -566,41 +699,75 @@ impl TryFrom for Vec { use crate::message::Message as InternalMessage; match internal_msg { InternalMessage::User { content, .. } => { - // Foundry doesn't support tool results in messages, so we skip them - let non_tool_content: Vec<_> = content - .into_iter() - .filter(|content| { - !matches!(content, crate::message::UserContent::ToolResult(_)) - }) - .collect(); - - let text_contents: Vec = non_tool_content - .into_iter() - .filter_map(|content| match content { - crate::message::UserContent::Text(crate::message::Text { text }) => { - Some(text) + let mut messages = Vec::new(); + let mut text_parts = Vec::new(); + + for part in content { + match part { + message::UserContent::Text(text) => text_parts.push(text.text), + message::UserContent::ToolResult(result) => { + let content_string = result + .content + .into_iter() + .map(|c| match c { + message::ToolResultContent::Text(t) => t.text, + _ => "[unsupported content]".to_string(), + }) + .collect::>() + .join("\n"); + + messages.push(Message::Tool { + tool_call_id: result.id, + content: content_string, + }); } - _ => None, - }) - .collect(); + _ => {} + } + } - Ok(vec![Message { - role: Role::User, - content: text_contents.join(" "), - }]) + if !text_parts.is_empty() { + messages.insert( + 0, + Message::User { + content: text_parts.join("\n"), + }, + ); + } + + Ok(messages) } InternalMessage::Assistant { content, .. } => { - let text_contents: Vec = content - .into_iter() - .filter_map(|content| match content { - crate::message::AssistantContent::Text(text) => Some(text.text), - _ => None, - }) - .collect(); - - Ok(vec![Message { - role: Role::Assistant, - content: text_contents.join(" "), + let mut text_content = None; + let mut tool_calls = Vec::new(); + + for part in content { + match part { + message::AssistantContent::Text(text) => { + text_content + .get_or_insert_with(String::new) + .push_str(&text.text); + } + message::AssistantContent::ToolCall(tc) => { + tool_calls.push(ToolCall { + id: tc.id, + r#type: "function".to_string(), + function: FunctionCall { + name: tc.function.name, + arguments: tc.function.arguments.to_string(), + }, + }); + } + _ => {} + } + } + + Ok(vec![Message::Assistant { + content: text_content, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, }]) } } @@ -609,22 +776,50 @@ impl TryFrom for Vec { impl From for crate::completion::Message { fn from(msg: Message) -> Self { - match msg.role { - Role::User => crate::completion::Message::User { - content: OneOrMany::one(crate::completion::message::UserContent::Text(Text { - text: msg.content, - })), - }, - Role::Assistant => crate::completion::Message::Assistant { - id: None, - content: OneOrMany::one(crate::completion::message::AssistantContent::Text({ - Text { text: msg.content } - })), - }, - Role::System => crate::completion::Message::User { - content: OneOrMany::one(crate::completion::message::UserContent::Text(Text { - text: msg.content, - })), + match msg { + Message::User { content } | Message::System { content } => { + crate::completion::Message::User { + content: OneOrMany::one(crate::completion::message::UserContent::Text(Text { + text: content, + })), + } + } + Message::Assistant { + content, + tool_calls, + } => { + let mut assistant_contents = Vec::new(); + if let Some(text) = content { + if !text.is_empty() { + assistant_contents.push(message::AssistantContent::Text(Text { text })); + } + } + if let Some(tcs) = tool_calls { + for tc in tcs { + let arguments: Value = serde_json::from_str(&tc.function.arguments) + .unwrap_or_else(|_| json!(tc.function.arguments)); + assistant_contents.push(message::AssistantContent::tool_call( + tc.id, + tc.function.name, + arguments, + )); + } + } + + crate::completion::Message::Assistant { + id: None, + content: OneOrMany::many(assistant_contents) + .unwrap_or_else(|_| OneOrMany::one(message::AssistantContent::text(""))), + } + } + Message::Tool { + tool_call_id, + content, + } => crate::completion::Message::User { + content: OneOrMany::one(message::UserContent::tool_result( + tool_call_id, + OneOrMany::one(message::ToolResultContent::text(content)), + )), }, } } @@ -632,8 +827,7 @@ impl From for crate::completion::Message { impl Message { pub fn system(content: &str) -> Self { - Self { - role: Role::System, + Self::System { content: content.to_owned(), } } @@ -711,7 +905,7 @@ mod tests { let sample_chat_response = json!({ "id": "chatcmpl-1234567890", "object": "chat.completion", - "created": "1677851234", + "created": 1677851234, "model": "Phi-4-mini-instruct-generic-cpu", "choices": [ { @@ -741,10 +935,61 @@ mod tests { ); } + #[test] + fn test_tool_call_deserialization_and_conversion() { + let tool_call_response = json!({ + "id": "chatcmpl-9pFN3aGu2dM1ALf1IixE23qG1Wp7u", + "object": "chat.completion", + "created": 1720235377, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_stools_get_flight_info_1720235377043", + "type": "function", + "function": { + "name": "get_flight_info", + "arguments": "{\"origin_city\":\"Miami\",\"destination_city\":\"Seattle\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 83, + "completion_tokens": 21, + "total_tokens": 104 + } + }); + + let chat_resp: CompletionResponse = serde_json::from_value(tool_call_response).unwrap(); + let conv_resp: completion::CompletionResponse = + chat_resp.try_into().unwrap(); + + assert_eq!(conv_resp.choice.len(), 1); + match conv_resp.choice.first() { + completion::AssistantContent::ToolCall(tc) => { + assert_eq!(tc.id, "call_stools_get_flight_info_1720235377043"); + assert_eq!(tc.function.name, "get_flight_info"); + assert_eq!( + tc.function.arguments, + json!({"origin_city": "Miami", "destination_city": "Seattle"}) + ); + } + _ => panic!("Expected a tool call"), + } + } + #[test] fn test_message_conversion() { - let provider_msg = Message { - role: Role::User, + let provider_msg = Message::User { content: "Test message".to_owned(), }; let comp_msg: crate::completion::Message = provider_msg.into(); @@ -763,32 +1008,25 @@ mod tests { } #[test] - fn test_system_content_from_string() { - let content = SystemContent::from("Test system message".to_string()); - assert_eq!(content.text, "Test system message"); - assert!(matches!(content.r#type, SystemContentType::Text)); - } - - #[test] - fn test_system_content_from_str() { - let content: SystemContent = "Test system message".parse().unwrap(); - assert_eq!(content.text, "Test system message"); - assert!(matches!(content.r#type, SystemContentType::Text)); - } - - #[test] - fn test_assistant_content_from_str() { - let content: AssistantContent = "Test assistant message".parse().unwrap(); - assert_eq!(content.text, "Test assistant message"); - } + fn test_tool_result_message_conversion() { + let rig_message = crate::message::Message::User { + content: OneOrMany::one(crate::message::UserContent::tool_result( + "call_123", + OneOrMany::one(crate::message::ToolResultContent::text("Flight found")), + )), + }; - #[test] - fn test_user_content_from_str() { - let content: UserContent = "Test user message".parse().unwrap(); - match content { - UserContent::Text { text } => { - assert_eq!(text, "Test user message"); + let provider_messages: Vec = rig_message.try_into().unwrap(); + assert_eq!(provider_messages.len(), 1); + match &provider_messages[0] { + Message::Tool { + tool_call_id, + content, + } => { + assert_eq!(tool_call_id, "call_123"); + assert_eq!(content, "Flight found"); } + _ => panic!("Expected a Tool message"), } } } diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 1d5976346..f550d6769 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -49,7 +49,7 @@ pub mod anthropic; pub mod azure; pub mod cohere; pub mod deepseek; -pub mod foundry; +pub mod foundry_local; pub mod galadriel; pub mod gemini; pub mod groq; From 42e6f18fc8f6409da41f84b061baa02c391f3af4 Mon Sep 17 00:00:00 2001 From: raj Date: Sun, 17 Aug 2025 13:16:19 +0530 Subject: [PATCH 7/8] fix: lint --- rig-core/src/providers/foundry_local.rs | 66 +++++++++++++------------ 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/rig-core/src/providers/foundry_local.rs b/rig-core/src/providers/foundry_local.rs index 884c4e375..3e7972477 100644 --- a/rig-core/src/providers/foundry_local.rs +++ b/rig-core/src/providers/foundry_local.rs @@ -605,43 +605,44 @@ impl completion::CompletionModel for CompletionModel { }; for line in text.lines() { - if line.starts_with("data: ") { - let data = &line[6..]; - if data == "[DONE]" { - break; - } + if line.starts_with("data: ") + && let Some(data) = line.strip_prefix("data: "){ + if data == "[DONE]" { + break; + } - let Ok(chunk) = serde_json::from_str::(data) else { - continue; - }; + let Ok(chunk) = serde_json::from_str::(data) else { + continue; + }; - final_response_id = chunk.id; - final_response_model = chunk.model; + final_response_id = chunk.id; + final_response_model = chunk.model; - for choice in chunk.choices { - if let Some(content) = choice.delta.content { - yield Ok(RawStreamingChoice::Message(content)); - } + for choice in chunk.choices { + if let Some(content) = choice.delta.content { + yield Ok(RawStreamingChoice::Message(content)); + } - if let Some(delta_tool_calls) = choice.delta.tool_calls { - for stc in delta_tool_calls { - let entry = tool_calls.entry(stc.index).or_default(); - if let Some(id) = stc.id { - entry.0 = Some(id); - } - if let Some(function) = stc.function { - if let Some(name) = function.name { - entry.1.name.get_or_insert_with(String::new).push_str(&name); + if let Some(delta_tool_calls) = choice.delta.tool_calls { + for stc in delta_tool_calls { + let entry = tool_calls.entry(stc.index).or_default(); + if let Some(id) = stc.id { + entry.0 = Some(id); } - if let Some(args) = function.arguments { - entry.1.arguments.get_or_insert_with(String::new).push_str(&args); + if let Some(function) = stc.function { + if let Some(name) = function.name { + entry.1.name.get_or_insert_with(String::new).push_str(&name); + } + if let Some(args) = function.arguments { + entry.1.arguments.get_or_insert_with(String::new).push_str(&args); + } } } } } } - } + } } @@ -774,6 +775,7 @@ impl TryFrom for Vec { } } + impl From for crate::completion::Message { fn from(msg: Message) -> Self { match msg { @@ -789,11 +791,10 @@ impl From for crate::completion::Message { tool_calls, } => { let mut assistant_contents = Vec::new(); - if let Some(text) = content { - if !text.is_empty() { - assistant_contents.push(message::AssistantContent::Text(Text { text })); - } + if let Some(text) = content && !text.is_empty() { + assistant_contents.push(message::AssistantContent::Text(Text { text })); } + if let Some(tcs) = tool_calls { for tc in tcs { let arguments: Value = serde_json::from_str(&tc.function.arguments) @@ -808,8 +809,9 @@ impl From for crate::completion::Message { crate::completion::Message::Assistant { id: None, - content: OneOrMany::many(assistant_contents) - .unwrap_or_else(|_| OneOrMany::one(message::AssistantContent::text(""))), + content: OneOrMany::many(assistant_contents).unwrap_or_else(|_| { + OneOrMany::one(message::AssistantContent::text("")) + }), } } Message::Tool { From 480d9d5abbe6f1dddbdd50b803457f1c54cb3e98 Mon Sep 17 00:00:00 2001 From: raj Date: Sun, 17 Aug 2025 13:24:04 +0530 Subject: [PATCH 8/8] fix: clippy :( --- rig-core/src/providers/foundry_local.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rig-core/src/providers/foundry_local.rs b/rig-core/src/providers/foundry_local.rs index 3e7972477..b855a9023 100644 --- a/rig-core/src/providers/foundry_local.rs +++ b/rig-core/src/providers/foundry_local.rs @@ -775,7 +775,6 @@ impl TryFrom for Vec { } } - impl From for crate::completion::Message { fn from(msg: Message) -> Self { match msg { @@ -791,7 +790,9 @@ impl From for crate::completion::Message { tool_calls, } => { let mut assistant_contents = Vec::new(); - if let Some(text) = content && !text.is_empty() { + if let Some(text) = content + && !text.is_empty() + { assistant_contents.push(message::AssistantContent::Text(Text { text })); } @@ -809,9 +810,8 @@ impl From for crate::completion::Message { crate::completion::Message::Assistant { id: None, - content: OneOrMany::many(assistant_contents).unwrap_or_else(|_| { - OneOrMany::one(message::AssistantContent::text("")) - }), + content: OneOrMany::many(assistant_contents) + .unwrap_or_else(|_| OneOrMany::one(message::AssistantContent::text(""))), } } Message::Tool {