Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions bindings/kotlin/uniffi/goose_llm/goose_llm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ internal interface UniffiLib : Library {
`systemPromptOverride`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`extensions`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
uniffi_out_err: UniffiRustCallStatus,
): RustBuffer.ByValue

Expand All @@ -848,6 +849,7 @@ internal interface UniffiLib : Library {
`providerName`: RustBuffer.ByValue,
`providerConfig`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
): Long

fun uniffi_goose_llm_fn_func_generate_structured_outputs(
Expand All @@ -856,12 +858,14 @@ internal interface UniffiLib : Library {
`systemPrompt`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`schema`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
): Long

fun uniffi_goose_llm_fn_func_generate_tooltip(
`providerName`: RustBuffer.ByValue,
`providerConfig`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
): Long

fun uniffi_goose_llm_fn_func_print_messages(
Expand Down Expand Up @@ -1101,19 +1105,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 15391.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 34350.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 4576.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 36439.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
Expand Down Expand Up @@ -2960,6 +2964,7 @@ fun `createCompletionRequest`(
`systemPromptOverride`: kotlin.String? = null,
`messages`: List<Message>,
`extensions`: List<ExtensionConfig>,
`requestId`: kotlin.String? = null,
): CompletionRequest =
FfiConverterTypeCompletionRequest.lift(
uniffiRustCall { _status ->
Expand All @@ -2971,6 +2976,7 @@ fun `createCompletionRequest`(
FfiConverterOptionalString.lower(`systemPromptOverride`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
FfiConverterOptionalString.lower(`requestId`),
_status,
)
},
Expand Down Expand Up @@ -3003,12 +3009,14 @@ suspend fun `generateSessionName`(
`providerName`: kotlin.String,
`providerConfig`: Value,
`messages`: List<Message>,
`requestId`: kotlin.String? = null,
): kotlin.String =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
FfiConverterString.lower(`providerName`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterOptionalString.lower(`requestId`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
Expand All @@ -3031,6 +3039,7 @@ suspend fun `generateStructuredOutputs`(
`systemPrompt`: kotlin.String,
`messages`: List<Message>,
`schema`: Value,
`requestId`: kotlin.String? = null,
): ProviderExtractResponse =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
Expand All @@ -3039,6 +3048,7 @@ suspend fun `generateStructuredOutputs`(
FfiConverterString.lower(`systemPrompt`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterTypeValue.lower(`schema`),
FfiConverterOptionalString.lower(`requestId`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
Expand All @@ -3059,12 +3069,14 @@ suspend fun `generateTooltip`(
`providerName`: kotlin.String,
`providerConfig`: Value,
`messages`: List<Message>,
`requestId`: kotlin.String? = null,
): kotlin.String =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip(
FfiConverterString.lower(`providerName`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterOptionalString.lower(`requestId`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
Expand Down
33 changes: 17 additions & 16 deletions crates/goose-llm/examples/image.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use std::{vec, fs};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use anyhow::Result;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use goose_llm::{
completion,
types::completion::{
CompletionRequest, CompletionResponse
},
message::MessageContent,
Message, ModelConfig,
types::completion::{CompletionRequest, CompletionResponse},
Message, ModelConfig,
};
use serde_json::json;
use std::{fs, vec};

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -18,7 +16,7 @@ async fn main() -> Result<()> {
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
let model_name = "kgoose-claude-4-sonnet"; // "gpt-4o";
let model_name = "goose-claude-4-sonnet"; // "gpt-4o";
let model_config = ModelConfig::new(model_name.to_string());

let system_preamble = "You are a helpful assistant.";
Expand All @@ -33,15 +31,18 @@ async fn main() -> Result<()> {

let messages = vec![user_msg];

let completion_response: CompletionResponse = completion(CompletionRequest::new(
provider.to_string(),
provider_config.clone(),
model_config.clone(),
Some(system_preamble.to_string()),
None,
messages,
vec![],
))
let completion_response: CompletionResponse = completion(
CompletionRequest::new(
provider.to_string(),
provider_config.clone(),
model_config.clone(),
Some(system_preamble.to_string()),
None,
messages,
vec![],
)
.with_request_id("test-image-1".to_string()),
)
.await?;

// Print the response
Expand Down
2 changes: 1 addition & 1 deletion crates/goose-llm/examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async fn main() -> Result<()> {
println!("\nCompletion Response:");
println!("{}", serde_json::to_string_pretty(&completion_response)?);

let tooltip = generate_tooltip(provider, provider_config.clone(), &messages).await?;
let tooltip = generate_tooltip(provider, provider_config.clone(), &messages, None).await?;
println!("\nTooltip: {}", tooltip);
}

Expand Down
7 changes: 6 additions & 1 deletion crates/goose-llm/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
// Call the LLM provider
let start_provider = Instant::now();
let mut response = provider
.complete(&system_prompt, &req.messages, &tools)
.complete(
&system_prompt,
&req.messages,
&tools,
req.request_id.as_deref(),
)
.await?;
let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
let usage_tokens = response.usage.total_tokens;
Expand Down
4 changes: 3 additions & 1 deletion crates/goose-llm/src/extractors/session_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ fn build_system_prompt() -> String {
}

/// Generates a short (≤4 words) session name
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_session_name(
provider_name: &str,
provider_config: JsonValueFfi,
messages: &[Message],
request_id: Option<String>,
) -> Result<String, ProviderError> {
// Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages
Expand Down Expand Up @@ -93,6 +94,7 @@ pub async fn generate_session_name(
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
request_id,
)
.await?;

Expand Down
4 changes: 3 additions & 1 deletion crates/goose-llm/src/extractors/tooltip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ fn build_system_prompt() -> String {

/// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results.
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_tooltip(
provider_name: &str,
provider_config: JsonValueFfi,
messages: &[Message],
request_id: Option<String>,
) -> Result<String, ProviderError> {
// Need at least two messages to generate a tooltip
if messages.len() < 2 {
Expand Down Expand Up @@ -148,6 +149,7 @@ pub async fn generate_tooltip(
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
request_id,
)
.await?;

Expand Down
4 changes: 4 additions & 0 deletions crates/goose-llm/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub trait Provider: Send + Sync {
/// * `system` - The system prompt that guides the model's behavior
/// * `messages` - The conversation history as a sequence of messages
/// * `tools` - Optional list of tools the model can use
/// * `request_id` - Optional request ID (only used by some providers like Databricks)
///
/// # Returns
/// A tuple containing the model's response message and provider usage statistics
Expand All @@ -81,6 +82,7 @@ pub trait Provider: Send + Sync {
system: &str,
messages: &[Message],
tools: &[Tool],
request_id: Option<&str>,
) -> Result<ProviderCompleteResponse, ProviderError>;

/// Structured extraction: always JSON‐Schema
Expand All @@ -90,6 +92,7 @@ pub trait Provider: Send + Sync {
/// * `messages` – conversation history
/// * `schema` – a JSON‐Schema for the expected output.
/// Will set strict=true for OpenAI & Databricks.
/// * `request_id` - Optional request ID (only used by some providers like Databricks)
///
/// # Returns
/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
Expand All @@ -102,6 +105,7 @@ pub trait Provider: Send + Sync {
system: &str,
messages: &[Message],
schema: &serde_json::Value,
request_id: Option<&str>,
) -> Result<ProviderExtractResponse, ProviderError>;
}

Expand Down
24 changes: 24 additions & 0 deletions crates/goose-llm/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ impl Provider for DatabricksProvider {
system: &str,
messages: &[Message],
tools: &[Tool],
request_id: Option<&str>,
) -> Result<ProviderCompleteResponse, ProviderError> {
let mut payload = create_request(
&self.model,
Expand All @@ -224,6 +225,17 @@ impl Provider for DatabricksProvider {
.expect("payload should have model key")
.remove("model");

// Add client_request_id if provided
if let Some(req_id) = request_id {
payload
.as_object_mut()
.expect("payload should be an object")
.insert(
"client_request_id".to_string(),
serde_json::Value::String(req_id.to_string()),
);
}

let response = self.post(payload.clone()).await?;

// Parse response
Expand All @@ -247,6 +259,7 @@ impl Provider for DatabricksProvider {
system: &str,
messages: &[Message],
schema: &Value,
request_id: Option<&str>,
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
Expand All @@ -267,6 +280,17 @@ impl Provider for DatabricksProvider {
}),
);

// Add client_request_id if provided
if let Some(req_id) = request_id {
payload
.as_object_mut()
.expect("payload should be an object")
.insert(
"client_request_id".to_string(),
serde_json::Value::String(req_id.to_string()),
);
}

// 3. Call OpenAI
let response = self.post(payload.clone()).await?;

Expand Down
Loading
Loading