diff --git a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt
index 76e60aaf7441..f01956947261 100644
--- a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt
+++ b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt
@@ -833,6 +833,7 @@ internal interface UniffiLib : Library {
        `systemPromptOverride`: RustBuffer.ByValue,
        `messages`: RustBuffer.ByValue,
        `extensions`: RustBuffer.ByValue,
+       `requestId`: RustBuffer.ByValue,
        uniffi_out_err: UniffiRustCallStatus,
    ): RustBuffer.ByValue
@@ -848,6 +849,7 @@ internal interface UniffiLib : Library {
        `providerName`: RustBuffer.ByValue,
        `providerConfig`: RustBuffer.ByValue,
        `messages`: RustBuffer.ByValue,
+       `requestId`: RustBuffer.ByValue,
    ): Long

    fun uniffi_goose_llm_fn_func_generate_structured_outputs(
        `providerName`: RustBuffer.ByValue,
        `providerConfig`: RustBuffer.ByValue,
        `systemPrompt`: RustBuffer.ByValue,
        `messages`: RustBuffer.ByValue,
        `schema`: RustBuffer.ByValue,
+       `requestId`: RustBuffer.ByValue,
    ): Long

    fun uniffi_goose_llm_fn_func_generate_tooltip(
        `providerName`: RustBuffer.ByValue,
        `providerConfig`: RustBuffer.ByValue,
        `messages`: RustBuffer.ByValue,
+       `requestId`: RustBuffer.ByValue,
    ): Long

    fun uniffi_goose_llm_fn_func_print_messages(
@@ -1101,19 +1105,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
     if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
-    if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) {
+    if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 15391.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
     if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
-    if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) {
+    if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 34350.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
-    if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) {
+    if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 4576.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
-    if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
+    if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 36439.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
     if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
@@ -2960,6 +2964,7 @@ fun `createCompletionRequest`(
    `systemPromptOverride`: kotlin.String? = null,
    `messages`: List<Message>,
    `extensions`: List<ExtensionConfig>,
+   `requestId`: kotlin.String? = null,
 ): CompletionRequest =
    FfiConverterTypeCompletionRequest.lift(
        uniffiRustCall { _status ->
@@ -2971,6 +2976,7 @@ fun `createCompletionRequest`(
            FfiConverterOptionalString.lower(`systemPromptOverride`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
+           FfiConverterOptionalString.lower(`requestId`),
            _status,
        )
    },
@@ -3003,12 +3009,14 @@ suspend fun `generateSessionName`(
    `providerName`: kotlin.String,
    `providerConfig`: Value,
    `messages`: List<Message>,
+   `requestId`: kotlin.String? = null,
 ): kotlin.String =
    uniffiRustCallAsync(
        UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
            FfiConverterString.lower(`providerName`),
            FfiConverterTypeValue.lower(`providerConfig`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
+           FfiConverterOptionalString.lower(`requestId`),
        ),
        { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
        { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3031,6 +3039,7 @@ suspend fun `generateStructuredOutputs`(
    `systemPrompt`: kotlin.String,
    `messages`: List<Message>,
    `schema`: Value,
+   `requestId`: kotlin.String? = null,
 ): ProviderExtractResponse =
    uniffiRustCallAsync(
        UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
            FfiConverterString.lower(`providerName`),
            FfiConverterTypeValue.lower(`providerConfig`),
            FfiConverterString.lower(`systemPrompt`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterTypeValue.lower(`schema`),
+           FfiConverterOptionalString.lower(`requestId`),
        ),
        { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
        { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3059,12 +3069,14 @@ suspend fun `generateTooltip`(
    `providerName`: kotlin.String,
    `providerConfig`: Value,
    `messages`: List<Message>,
+   `requestId`: kotlin.String? = null,
 ): kotlin.String =
    uniffiRustCallAsync(
        UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip(
            FfiConverterString.lower(`providerName`),
            FfiConverterTypeValue.lower(`providerConfig`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
+           FfiConverterOptionalString.lower(`requestId`),
        ),
        { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
        { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
diff --git a/crates/goose-llm/examples/image.rs b/crates/goose-llm/examples/image.rs
index 63da3951d409..7c607713e9cf 100644
--- a/crates/goose-llm/examples/image.rs
+++ b/crates/goose-llm/examples/image.rs
@@ -1,15 +1,13 @@
-use std::{vec, fs};
-use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
 use anyhow::Result;
+use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
 use goose_llm::{
     completion,
-    types::completion::{
-        CompletionRequest, CompletionResponse
-    },
     message::MessageContent,
-    Message, ModelConfig,
+    types::completion::{CompletionRequest, CompletionResponse},
+    Message, ModelConfig,
 };
 use serde_json::json;
+use std::{fs, vec};

 #[tokio::main]
 async fn main() -> Result<()> {
@@ -18,7 +16,7 @@ async fn main() -> Result<()> {
         "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
         "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
     });
-    let model_name = "kgoose-claude-4-sonnet"; // "gpt-4o";
+    let model_name = "goose-claude-4-sonnet"; // "gpt-4o";
     let model_config = ModelConfig::new(model_name.to_string());

     let system_preamble = "You are a helpful assistant.";
@@ -33,15 +31,18 @@ async fn main() -> Result<()> {

     let messages = vec![user_msg];

-    let completion_response: CompletionResponse = completion(CompletionRequest::new(
-        provider.to_string(),
-        provider_config.clone(),
-        model_config.clone(),
-        Some(system_preamble.to_string()),
-        None,
-        messages,
-        vec![],
-    ))
+    let completion_response: CompletionResponse = completion(
+        CompletionRequest::new(
+            provider.to_string(),
+            provider_config.clone(),
+            model_config.clone(),
+            Some(system_preamble.to_string()),
+            None,
+            messages,
+            vec![],
+        )
+        .with_request_id("test-image-1".to_string()),
+    )
     .await?;

     // Print the response
diff --git a/crates/goose-llm/examples/simple.rs b/crates/goose-llm/examples/simple.rs
index e7d36a7870ed..efab4b0abc57 100644
--- a/crates/goose-llm/examples/simple.rs
+++ b/crates/goose-llm/examples/simple.rs
@@ -116,7 +116,7 @@ async fn main() -> Result<()> {
         println!("\nCompletion Response:");
         println!("{}", serde_json::to_string_pretty(&completion_response)?);

-        let tooltip = generate_tooltip(provider, provider_config.clone(), &messages).await?;
+        let tooltip = generate_tooltip(provider, provider_config.clone(), &messages, None).await?;
         println!("\nTooltip: {}", tooltip);
     }
diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs
index d39b1b8db830..13f09810b8c4 100644
--- a/crates/goose-llm/src/completion.rs
+++ b/crates/goose-llm/src/completion.rs
@@ -46,7 +46,12 @@ pub async fn completion(req: CompletionRequest) -> Result String {
 }

 /// Generates a short (≤4 words) session name
-#[uniffi::export(async_runtime = "tokio")]
+#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
 pub async fn generate_session_name(
     provider_name: &str,
     provider_config: JsonValueFfi,
     messages: &[Message],
+    request_id: Option<String>,
 ) -> Result {
     // Collect up to the first 3 user messages (truncated to 300 chars each)
     let context: Vec = messages
@@ -93,6 +94,7 @@ pub async fn generate_session_name(
         &system_prompt,
         &[Message::user().with_text(&user_msg_text)],
         schema,
+        request_id,
     )
     .await?;
diff --git a/crates/goose-llm/src/extractors/tooltip.rs b/crates/goose-llm/src/extractors/tooltip.rs
index 37d83ffe59e1..48336a546ea6 100644
--- a/crates/goose-llm/src/extractors/tooltip.rs
+++ b/crates/goose-llm/src/extractors/tooltip.rs
@@ -52,11 +52,12 @@ fn build_system_prompt() -> String {
 }

 /// Generates a tooltip summarizing the last two messages in the session,
 /// including any tool calls or results.
-#[uniffi::export(async_runtime = "tokio")]
+#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
 pub async fn generate_tooltip(
     provider_name: &str,
     provider_config: JsonValueFfi,
     messages: &[Message],
+    request_id: Option<String>,
 ) -> Result {
     // Need at least two messages to generate a tooltip
     if messages.len() < 2 {
@@ -148,6 +149,7 @@ pub async fn generate_tooltip(
         &system_prompt,
         &[Message::user().with_text(&user_msg_text)],
         schema,
+        request_id,
     )
     .await?;
diff --git a/crates/goose-llm/src/providers/base.rs b/crates/goose-llm/src/providers/base.rs
index dcfecbd1e7f3..92a3948df28f 100644
--- a/crates/goose-llm/src/providers/base.rs
+++ b/crates/goose-llm/src/providers/base.rs
@@ -69,6 +69,7 @@ pub trait Provider: Send + Sync {
     /// * `system` - The system prompt that guides the model's behavior
     /// * `messages` - The conversation history as a sequence of messages
     /// * `tools` - Optional list of tools the model can use
+    /// * `request_id` - Optional request ID (only used by some providers like Databricks)
     ///
     /// # Returns
     /// A tuple containing the model's response message and provider usage statistics
@@ -81,6 +82,7 @@ pub trait Provider: Send + Sync {
         system: &str,
         messages: &[Message],
         tools: &[Tool],
+        request_id: Option<&str>,
     ) -> Result;

     /// Structured extraction: always JSON‐Schema
     /// * `system` – system prompt
     /// * `messages` – conversation history
     /// * `schema` – a JSON‐Schema for the expected output.
     /// Will set strict=true for OpenAI & Databricks.
+    /// * `request_id` - Optional request ID (only used by some providers like Databricks)
     ///
     /// # Returns
     /// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
@@ -102,6 +105,7 @@ pub trait Provider: Send + Sync {
         system: &str,
         messages: &[Message],
         schema: &serde_json::Value,
+        request_id: Option<&str>,
     ) -> Result;
 }
diff --git a/crates/goose-llm/src/providers/databricks.rs b/crates/goose-llm/src/providers/databricks.rs
index 3dd31493c1cd..0bfe2ffef67b 100644
--- a/crates/goose-llm/src/providers/databricks.rs
+++ b/crates/goose-llm/src/providers/databricks.rs
@@ -210,6 +210,7 @@ impl Provider for DatabricksProvider {
         system: &str,
         messages: &[Message],
         tools: &[Tool],
+        request_id: Option<&str>,
     ) -> Result {
         let mut payload = create_request(
             &self.model,
@@ -224,6 +225,17 @@ impl Provider for DatabricksProvider {
             .expect("payload should have model key")
             .remove("model");

+        // Add client_request_id if provided
+        if let Some(req_id) = request_id {
+            payload
+                .as_object_mut()
+                .expect("payload should be an object")
+                .insert(
+                    "client_request_id".to_string(),
+                    serde_json::Value::String(req_id.to_string()),
+                );
+        }
+
         let response = self.post(payload.clone()).await?;

         // Parse response
@@ -247,6 +259,7 @@ impl Provider for DatabricksProvider {
         system: &str,
         messages: &[Message],
         schema: &Value,
+        request_id: Option<&str>,
     ) -> Result {
         // 1. Build base payload (no tools)
         let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
@@ -267,6 +280,17 @@ impl Provider for DatabricksProvider {
             }),
         );

+        // Add client_request_id if provided
+        if let Some(req_id) = request_id {
+            payload
+                .as_object_mut()
+                .expect("payload should be an object")
+                .insert(
+                    "client_request_id".to_string(),
+                    serde_json::Value::String(req_id.to_string()),
+                );
+        }
+
         // 3. Call OpenAI
         let response = self.post(payload.clone()).await?;
diff --git a/crates/goose-llm/src/providers/formats/databricks.rs b/crates/goose-llm/src/providers/formats/databricks.rs
index ac01e2cd789a..37343f2ebe09 100644
--- a/crates/goose-llm/src/providers/formats/databricks.rs
+++ b/crates/goose-llm/src/providers/formats/databricks.rs
@@ -7,10 +7,7 @@ use crate::{
     providers::{
         base::Usage,
         errors::ProviderError,
-        utils::{
-            convert_image, detect_image_path, is_valid_function_name, load_image_file,
-            sanitize_function_name, ImageFormat,
-        },
+        utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat},
     },
     types::core::{Content, Role, Tool, ToolCall, ToolError},
 };
@@ -34,30 +31,17 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
             match content {
                 MessageContent::Text(text) => {
                     if !text.text.is_empty() {
-                        // Check for image paths in the text
-                        if let Some(image_path) = detect_image_path(&text.text) {
-                            has_multiple_content = true;
-                            // Try to load and convert the image
-                            if let Ok(image) = load_image_file(image_path) {
-                                content_array.push(json!({
-                                    "type": "text",
-                                    "text": text.text
-                                }));
-                                content_array.push(convert_image(&image, image_format));
-                            } else {
-                                content_array.push(json!({
-                                    "type": "text",
-                                    "text": text.text
-                                }));
-                            }
-                        } else {
-                            content_array.push(json!({
-                                "type": "text",
-                                "text": text.text
-                            }));
-                        }
+                        content_array.push(json!({
+                            "type": "text",
+                            "text": text.text
+                        }));
                     }
                 }
+                MessageContent::Image(image) => {
+                    // Handle direct image content
+                    let converted_image = convert_image(image, image_format);
+                    content_array.push(converted_image);
+                }
                 MessageContent::Thinking(content) => {
                     has_multiple_content = true;
                     content_array.push(json!({
@@ -166,11 +150,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
                         }
                     }
                 }
-                MessageContent::Image(image) => {
-                    // Handle direct image content
-                    let converted_image = convert_image(image, image_format);
-                    content_array.push(converted_image);
-                }
             }
         }
@@ -787,40 +766,6 @@ mod tests {
         Ok(())
     }

-    #[test]
-    fn test_format_messages_with_image_path() -> anyhow::Result<()> {
-        // Create a temporary PNG file with valid PNG magic numbers
-        let temp_dir = tempfile::tempdir()?;
-        let png_path = temp_dir.path().join("test.png");
-        let png_data = [
-            0x89, 0x50, 0x4E, 0x47, // PNG magic number
-            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
-            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
-        ];
-        std::fs::write(&png_path, png_data)?;
-        let png_path_str = png_path.to_str().unwrap();
-
-        // Create message with image path
-        let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
-        let spec = format_messages(&[message], &ImageFormat::OpenAi);
-
-        assert_eq!(spec.len(), 1);
-        assert_eq!(spec[0]["role"], "user");
-
-        // Content should be an array with text and image
-        let content = spec[0]["content"].as_array().unwrap();
-        assert_eq!(content.len(), 2);
-        assert_eq!(content[0]["type"], "text");
-        assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
-        assert_eq!(content[1]["type"], "image_url");
-        assert!(content[1]["image_url"]["url"]
-            .as_str()
-            .unwrap()
-            .starts_with("data:image/png;base64,"));
-
-        Ok(())
-    }
-
     #[test]
     fn test_response_to_message_text() -> anyhow::Result<()> {
         let response = json!({
diff --git a/crates/goose-llm/src/providers/formats/openai.rs b/crates/goose-llm/src/providers/formats/openai.rs
index afc48745cb59..a2eb43b414eb 100644
--- a/crates/goose-llm/src/providers/formats/openai.rs
+++ b/crates/goose-llm/src/providers/formats/openai.rs
@@ -7,10 +7,7 @@ use crate::{
     providers::{
         base::Usage,
         errors::ProviderError,
-        utils::{
-            convert_image, detect_image_path, is_valid_function_name, load_image_file,
-            sanitize_function_name, ImageFormat,
-        },
+        utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat},
     },
     types::core::{Content, Role, Tool, ToolCall, ToolError},
 };
@@ -31,23 +28,13 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
         match content {
             MessageContent::Text(text) => {
                 if !text.text.is_empty() {
-                    // Check for image paths in the text
-                    if let Some(image_path) = detect_image_path(&text.text) {
-                        // Try to load and convert the image
-                        if let Ok(image) = load_image_file(image_path) {
-                            converted["content"] = json!([
-                                {"type": "text", "text": text.text},
-                                convert_image(&image, image_format)
-                            ]);
-                        } else {
-                            // If image loading fails, just use the text
-                            converted["content"] = json!(text.text);
-                        }
-                    } else {
-                        converted["content"] = json!(text.text);
-                    }
+                    converted["content"] = json!(text.text);
                 }
             }
+            MessageContent::Image(image) => {
+                // Handle direct image content
+                converted["content"] = json!([convert_image(image, image_format)]);
+            }
             MessageContent::Thinking(_) => {
                 // Thinking blocks are not directly used in OpenAI format
                 continue;
@@ -134,10 +121,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
                 }
             }
         }
-        MessageContent::Image(image) => {
-            // Handle direct image content
-            converted["content"] = json!([convert_image(image, image_format)]);
-        }
     }
 }
@@ -664,40 +647,6 @@ mod tests {
         Ok(())
     }

-    #[test]
-    fn test_format_messages_with_image_path() -> anyhow::Result<()> {
-        // Create a temporary PNG file with valid PNG magic numbers
-        let temp_dir = tempfile::tempdir()?;
-        let png_path = temp_dir.path().join("test.png");
-        let png_data = [
-            0x89, 0x50, 0x4E, 0x47, // PNG magic number
-            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
-            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
-        ];
-        std::fs::write(&png_path, png_data)?;
-        let png_path_str = png_path.to_str().unwrap();
-
-        // Create message with image path
-        let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
-        let spec = format_messages(&[message], &ImageFormat::OpenAi);
-
-        assert_eq!(spec.len(), 1);
-        assert_eq!(spec[0]["role"], "user");
-
-        // Content should be an array with text and image
-        let content = spec[0]["content"].as_array().unwrap();
-        assert_eq!(content.len(), 2);
-        assert_eq!(content[0]["type"], "text");
-        assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
-        assert_eq!(content[1]["type"], "image_url");
-        assert!(content[1]["image_url"]["url"]
-            .as_str()
-            .unwrap()
-            .starts_with("data:image/png;base64,"));
-
-        Ok(())
-    }
-
     #[test]
     fn test_response_to_message_text() -> anyhow::Result<()> {
         let response = json!({
diff --git a/crates/goose-llm/src/providers/openai.rs b/crates/goose-llm/src/providers/openai.rs
index bc0dc0884823..82d736f366cf 100644
--- a/crates/goose-llm/src/providers/openai.rs
+++ b/crates/goose-llm/src/providers/openai.rs
@@ -149,6 +149,7 @@ impl Provider for OpenAiProvider {
         system: &str,
         messages: &[Message],
         tools: &[Tool],
+        _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it
     ) -> Result {
         let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
@@ -175,6 +176,7 @@ impl Provider for OpenAiProvider {
         system: &str,
         messages: &[Message],
         schema: &Value,
+        _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it
     ) -> Result {
         // 1. Build base payload (no tools)
         let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
diff --git a/crates/goose-llm/src/providers/utils.rs b/crates/goose-llm/src/providers/utils.rs
index 1a3945dcb15c..b6c00e7bf237 100644
--- a/crates/goose-llm/src/providers/utils.rs
+++ b/crates/goose-llm/src/providers/utils.rs
@@ -181,30 +181,6 @@ fn is_image_file(path: &Path) -> bool {
     false
 }

-/// Detect if a string contains a path to an image file
-pub fn detect_image_path(text: &str) -> Option<&str> {
-    // Basic image file extension check
-    let extensions = [".png", ".jpg", ".jpeg"];
-
-    // Find any word that ends with an image extension
-    for word in text.split_whitespace() {
-        if extensions
-            .iter()
-            .any(|ext| word.to_lowercase().ends_with(ext))
-        {
-            let path = Path::new(word);
-            // Check if it's an absolute path and file exists
-            if path.is_absolute() && path.is_file() {
-                // Verify it's actually an image file
-                if is_image_file(path) {
-                    return Some(word);
-                }
-            }
-        }
-    }
-    None
-}
-
 /// Convert a local image file to base64 encoded ImageContent
 pub fn load_image_file(path: &str) -> Result {
     let path = Path::new(path);
@@ -267,81 +243,6 @@ pub fn emit_debug_trace(
 mod tests {
     use super::*;

-    #[test]
-    fn test_detect_image_path() {
-        // Create a temporary PNG file with valid PNG magic numbers
-        let temp_dir = tempfile::tempdir().unwrap();
-        let png_path = temp_dir.path().join("test.png");
-        let png_data = [
-            0x89, 0x50, 0x4E, 0x47, // PNG magic number
-            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
-            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
-        ];
-        std::fs::write(&png_path, png_data).unwrap();
-        let png_path_str = png_path.to_str().unwrap();
-
-        // Create a fake PNG (wrong magic numbers)
-        let fake_png_path = temp_dir.path().join("fake.png");
-        std::fs::write(&fake_png_path, b"not a real png").unwrap();
-
-        // Test with valid PNG file using absolute path
-        let text = format!("Here is an image {}", png_path_str);
-        assert_eq!(detect_image_path(&text), Some(png_path_str));
-
-        // Test with non-image file that has .png extension
-        let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
-        assert_eq!(detect_image_path(&text), None);
-
-        // Test with non-existent file
-        let text = "Here is a fake.png that doesn't exist";
-        assert_eq!(detect_image_path(text), None);
-
-        // Test with non-image file
-        let text = "Here is a file.txt";
-        assert_eq!(detect_image_path(text), None);
-
-        // Test with relative path (should not match)
-        let text = "Here is a relative/path/image.png";
-        assert_eq!(detect_image_path(text), None);
-    }
-
-    #[test]
-    fn test_load_image_file() {
-        // Create a temporary PNG file with valid PNG magic numbers
-        let temp_dir = tempfile::tempdir().unwrap();
-        let png_path = temp_dir.path().join("test.png");
-        let png_data = [
-            0x89, 0x50, 0x4E, 0x47, // PNG magic number
-            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
-            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
-        ];
-        std::fs::write(&png_path, png_data).unwrap();
-        let png_path_str = png_path.to_str().unwrap();
-
-        // Create a fake PNG (wrong magic numbers)
-        let fake_png_path = temp_dir.path().join("fake.png");
-        std::fs::write(&fake_png_path, b"not a real png").unwrap();
-        let fake_png_path_str = fake_png_path.to_str().unwrap();
-
-        // Test loading valid PNG file
-        let result = load_image_file(png_path_str);
-        assert!(result.is_ok());
-        let image = result.unwrap();
-        assert_eq!(image.mime_type, "image/png");
-
-        // Test loading fake PNG file
-        let result = load_image_file(fake_png_path_str);
-        assert!(result.is_err());
-        assert!(result
-            .unwrap_err()
-            .to_string()
-            .contains("not a valid image"));
-
-        // Test non-existent file
-        let result = load_image_file("nonexistent.png");
-        assert!(result.is_err());
-    }
-
     #[test]
     fn test_sanitize_function_name() {
         assert_eq!(sanitize_function_name("hello-world"), "hello-world");
diff --git a/crates/goose-llm/src/structured_outputs.rs b/crates/goose-llm/src/structured_outputs.rs
index 8f478d8aa184..b6690b641e74 100644
--- a/crates/goose-llm/src/structured_outputs.rs
+++ b/crates/goose-llm/src/structured_outputs.rs
@@ -6,13 +6,14 @@ use crate::{

 /// Generates a structured output based on the provided schema,
 /// system prompt and user messages.
-#[uniffi::export(async_runtime = "tokio")]
+#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
 pub async fn generate_structured_outputs(
     provider_name: &str,
     provider_config: JsonValueFfi,
     system_prompt: &str,
     messages: &[Message],
     schema: JsonValueFfi,
+    request_id: Option<String>,
 ) -> Result {
     // Use OpenAI models specifically for this task
     let model_name = if provider_name == "databricks" {
@@ -23,7 +24,9 @@ pub async fn generate_structured_outputs(
     let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
     let provider = create(provider_name, provider_config, model_cfg)?;

-    let resp = provider.extract(system_prompt, messages, &schema).await?;
+    let resp = provider
+        .extract(system_prompt, messages, &schema, request_id.as_deref())
+        .await?;

     Ok(resp)
 }
diff --git a/crates/goose-llm/src/types/completion.rs b/crates/goose-llm/src/types/completion.rs
index 21e0bcd9ddd3..e6a6fd22ff20 100644
--- a/crates/goose-llm/src/types/completion.rs
+++ b/crates/goose-llm/src/types/completion.rs
@@ -20,6 +20,7 @@ pub struct CompletionRequest {
     pub system_prompt_override: Option<String>,
     pub messages: Vec<Message>,
     pub extensions: Vec<ExtensionConfig>,
+    pub request_id: Option<String>,
 }

 impl CompletionRequest {
@@ -40,11 +41,17 @@ impl CompletionRequest {
             system_preamble,
             messages,
             extensions,
+            request_id: None,
         }
     }
+
+    pub fn with_request_id(mut self, request_id: String) -> Self {
+        self.request_id = Some(request_id);
+        self
+    }
 }

-#[uniffi::export(default(system_preamble = None, system_prompt_override = None))]
+#[uniffi::export(default(system_preamble = None, system_prompt_override = None, request_id = None))]
 pub fn create_completion_request(
     provider_name: &str,
     provider_config: JsonValueFfi,
@@ -53,8 +60,9 @@ pub fn create_completion_request(
     system_prompt_override: Option<String>,
     messages: Vec<Message>,
     extensions: Vec<ExtensionConfig>,
+    request_id: Option<String>,
 ) -> CompletionRequest {
-    CompletionRequest::new(
+    let mut request = CompletionRequest::new(
         provider_name.to_string(),
         provider_config,
         model_config,
@@ -62,7 +70,13 @@ pub fn create_completion_request(
         system_prompt_override,
         messages,
         extensions,
-    )
+    );
+
+    if let Some(req_id) = request_id {
+        request = request.with_request_id(req_id);
+    }
+
+    request
 }

 uniffi::custom_type!(CompletionRequest, String, {
diff --git a/crates/goose-llm/tests/extract_session_name.rs b/crates/goose-llm/tests/extract_session_name.rs
index 5326fdbe780d..58d0a6b4921e 100644
--- a/crates/goose-llm/tests/extract_session_name.rs
+++ b/crates/goose-llm/tests/extract_session_name.rs
@@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result Result {}", provider_type, resp.data);
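A minimal caller-side sketch of the new request_id plumbing, modeled on examples/image.rs and examples/simple.rs. The environment variables, model name, and the flat goose_llm::generate_session_name import are assumptions borrowed from those examples (the real path may sit under an extractors module), not guarantees made by this patch; Databricks forwards the ID upstream as client_request_id, while providers such as OpenAI accept the argument and ignore it.

use goose_llm::{completion, generate_session_name, types::completion::CompletionRequest, Message, ModelConfig};
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Provider setup mirrors examples/image.rs; the env vars and model name are
    // placeholders for whatever Databricks endpoint you actually use.
    let provider = "databricks";
    let provider_config = json!({
        "host": std::env::var("DATABRICKS_HOST")?,
        "token": std::env::var("DATABRICKS_TOKEN")?,
    });
    let model_config = ModelConfig::new("goose-claude-4-sonnet".to_string());
    let messages = vec![Message::user().with_text("What is 7 x 6?")];

    // Extractor path: the request ID is an optional trailing argument
    // (it defaults to None over the FFI boundary).
    let name = generate_session_name(
        provider,
        provider_config.clone(),
        &messages,
        Some("trace-001".to_string()),
    )
    .await?;
    println!("Session name: {}", name);

    // Completion path: attach the ID with the new builder; Databricks sends it
    // as client_request_id, other providers simply drop it.
    let req = CompletionRequest::new(
        provider.to_string(),
        provider_config,
        model_config,
        Some("You are a helpful assistant.".to_string()),
        None,
        messages,
        vec![],
    )
    .with_request_id("trace-001".to_string());
    let resp = completion(req).await?;
    println!("{}", serde_json::to_string_pretty(&resp)?);
    Ok(())
}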