diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 8265c5222d63..551aeb3b03ee 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -24,6 +24,88 @@ struct DatabricksMessage { tool_call_id: Option, } +fn format_text_content(text: &str, image_format: &ImageFormat) -> (Vec, bool) { + let mut items = vec![json!({"type": "text", "text": text})]; + let has_image = if let Some(path) = detect_image_path(text) { + if let Ok(image) = load_image_file(path) { + items.push(convert_image(&image, image_format)); + } + true + } else { + false + }; + (items, has_image) +} + +fn format_tool_response( + response: &crate::conversation::message::ToolResponse, + image_format: &ImageFormat, +) -> Vec { + let mut result = Vec::new(); + + match &response.tool_result { + Ok(call_result) => { + let abridged: Vec<_> = call_result + .content + .iter() + .filter(|c| c.audience().is_none_or(|a| a.contains(&Role::Assistant))) + .map(|c| c.raw.clone()) + .collect(); + + let mut tool_content = Vec::new(); + let mut image_messages = Vec::new(); + + for content in abridged { + match content { + RawContent::Image(image) => { + tool_content.push(Content::text( + "This tool result included an image that is uploaded in the next message.", + )); + image_messages.push(DatabricksMessage { + role: "user".to_string(), + content: [convert_image(&image.no_annotation(), image_format)].into(), + tool_calls: None, + tool_call_id: None, + }); + } + RawContent::Resource(resource) => { + let text = match &resource.resource { + ResourceContents::TextResourceContents { text, .. } => text.clone(), + _ => String::new(), + }; + tool_content.push(Content::text(text)); + } + _ => tool_content.push(content.no_annotation()), + } + } + + let tool_response_content: Value = json!(tool_content + .iter() + .filter_map(|c| c.as_text().map(|t| t.text.clone())) + .collect::>() + .join(" ")); + + result.push(DatabricksMessage { + content: tool_response_content, + role: "tool".to_string(), + tool_call_id: Some(response.id.clone()), + tool_calls: None, + }); + result.extend(image_messages); + } + Err(e) => { + result.push(DatabricksMessage { + role: "tool".to_string(), + content: format!("The tool call returned the following error:\n{}", e).into(), + tool_call_id: Some(response.id.clone()), + tool_calls: None, + }); + } + } + + result +} + /// Convert internal Message format to Databricks' API message specification /// Databricks is mostly OpenAI compatible, but has some differences (reasoning type, etc) /// some openai compatible endpoints use the anthropic image spec at the content level @@ -49,53 +131,27 @@ fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec { if !text.text.is_empty() { - // Check for image paths in the text - if let Some(image_path) = detect_image_path(&text.text) { - has_multiple_content = true; - // Try to load and convert the image - if let Ok(image) = load_image_file(image_path) { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - content_array.push(convert_image(&image, image_format)); - } else { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - } - } else { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - } + let (items, multi) = format_text_content(&text.text, image_format); + content_array.extend(items); + has_multiple_content |= multi; } } MessageContent::Thinking(content) => { has_multiple_content = true; content_array.push(json!({ "type": "reasoning", - "summary": [ - { - "type": "summary_text", - "text": content.thinking, - "signature": content.signature - } - ] + "summary": [{ + "type": "summary_text", + "text": content.thinking, + "signature": content.signature + }] })); } MessageContent::RedactedThinking(content) => { has_multiple_content = true; content_array.push(json!({ "type": "reasoning", - "summary": [ - { - "type": "summary_encrypted_text", - "data": content.data - } - ] + "summary": [{"type": "summary_encrypted_text", "data": content.data}] })); } MessageContent::ToolRequest(request) => { @@ -103,152 +159,58 @@ fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec { let sanitized_name = sanitize_function_name(&tool_call.name); - let arguments_str = match &tool_call.arguments { - Some(args) => { + let arguments_str = tool_call + .arguments + .as_ref() + .map(|args| { serde_json::to_string(args).unwrap_or_else(|_| "{}".to_string()) - } - None => "{}".to_string(), - }; + }) + .unwrap_or_else(|| "{}".to_string()); - let tool_calls = converted.tool_calls.get_or_insert_default(); - tool_calls.push(json!({ + converted.tool_calls.get_or_insert_default().push(json!({ "id": request.id, "type": "function", - "function": { - "name": sanitized_name, - "arguments": arguments_str, - } + "function": {"name": sanitized_name, "arguments": arguments_str} })); } Err(e) => { - content_array.push(json!({ - "type": "text", - "text": format!("Error: {}", e) - })); + content_array + .push(json!({"type": "text", "text": format!("Error: {}", e)})); } } } - MessageContent::SystemNotification(_) => { - continue; - } MessageContent::ToolResponse(response) => { - match &response.tool_result { - Ok(call_result) => { - // Send only contents with no audience or with Assistant in the audience - let abridged: Vec<_> = call_result - .content - .iter() - .filter(|content| { - content - .audience() - .is_none_or(|audience| audience.contains(&Role::Assistant)) - }) - .map(|content| content.raw.clone()) - .collect(); - - // Process all content, replacing images with placeholder text - let mut tool_content = Vec::new(); - let mut image_messages = Vec::new(); - - for content in abridged { - match content { - RawContent::Image(image) => { - tool_content.push(Content::text("This tool result included an image that is uploaded in the next message.")); - image_messages.push(DatabricksMessage { - role: "user".to_string(), - content: [convert_image( - &image.no_annotation(), - image_format, - )] - .into(), - tool_calls: None, - tool_call_id: None, - }); - } - RawContent::Resource(resource) => { - let text = match &resource.resource { - ResourceContents::TextResourceContents { - text, .. - } => text.clone(), - _ => String::new(), - }; - tool_content.push(Content::text(text)); - } - _ => { - tool_content.push(content.no_annotation()); - } - } - } - let tool_response_content: Value = json!(tool_content - .iter() - .filter_map(|content| content.as_text().map(|t| t.text.clone())) - .collect::>() - .join(" ")); - - result.push(DatabricksMessage { - content: tool_response_content, - role: "tool".to_string(), - tool_call_id: Some(response.id.clone()), - tool_calls: None, - }); - // Then add any image messages that need to follow - result.extend(image_messages); - } - Err(e) => { - // A tool result error is shown as output so the model can interpret the error message - result.push(DatabricksMessage { - role: "tool".to_string(), - content: format!( - "The tool call returned the following error:\n{}", - e - ) - .into(), - tool_call_id: Some(response.id.clone()), - tool_calls: None, - }); - } - } + result.extend(format_tool_response(response, image_format)); } - MessageContent::ToolConfirmationRequest(_) => {} - MessageContent::ActionRequired(_) => {} MessageContent::Image(image) => { content_array.push(convert_image(image, image_format)); } MessageContent::FrontendToolRequest(req) => { - // Frontend tool requests are converted to text messages - if let Ok(tool_call) = &req.tool_call { - content_array.push(json!({ - "type": "text", - "text": format!( - "Frontend tool request: {} ({})", - tool_call.name, - serde_json::to_string_pretty(&tool_call.arguments).unwrap() - ) - })); - } else { - content_array.push(json!({ - "type": "text", - "text": format!( - "Frontend tool request error: {}", - req.tool_call.as_ref().unwrap_err() - ) - })); - } + let text = match &req.tool_call { + Ok(tool_call) => format!( + "Frontend tool request: {} ({})", + tool_call.name, + serde_json::to_string_pretty(&tool_call.arguments).unwrap() + ), + Err(e) => format!("Frontend tool request error: {}", e), + }; + content_array.push(json!({"type": "text", "text": text})); } + MessageContent::SystemNotification(_) + | MessageContent::ToolConfirmationRequest(_) + | MessageContent::ActionRequired(_) => {} } } if !content_array.is_empty() { - // If we only have a single text content and no other special content, - // use the simple string format - if content_array.len() == 1 + converted.content = if content_array.len() == 1 && !has_multiple_content && content_array[0]["type"] == "text" { - converted.content = json!(content_array[0]["text"]); + json!(content_array[0]["text"]) } else { - converted.content = json!(content_array); - } + json!(content_array) + }; } if !content_array.is_empty() || has_tool_calls {