From 406428bb9b06dfb2984e5259c4e6689f870e9db1 Mon Sep 17 00:00:00 2001 From: David Katz Date: Wed, 25 Feb 2026 13:28:00 -0500 Subject: [PATCH] initial impl --- Cargo.lock | 2 +- crates/goose/src/providers/databricks.rs | 124 ++++++++++++++++------- 2 files changed, 90 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4282ebe30e7f..646da1978b99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4375,7 +4375,7 @@ dependencies = [ [[package]] name = "goose-acp-macros" -version = "1.24.0" +version = "1.25.0" dependencies = [ "proc-macro2", "quote", diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index a215ec4dd93e..728d539a8065 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -1,18 +1,28 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; use futures::future::BoxFuture; +use futures::{StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::io; use std::time::Duration; +use tokio::pin; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tokio_util::io::StreamReader; use super::api_client::{ApiClient, AuthMethod, AuthProvider}; use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::formats::databricks::create_request; +use super::formats::openai_responses::{ + create_responses_request, responses_api_to_streaming_message, +}; use super::oauth; use super::openai_compatible::{ - handle_response_openai_compat, map_http_error_to_provider_error, stream_openai_compat, + handle_response_openai_compat, handle_status_openai_compat, map_http_error_to_provider_error, + stream_openai_compat, }; use super::retry::ProviderRetry; use super::utils::{ImageFormat, RequestLog}; @@ -208,9 +218,16 @@ impl DatabricksProvider { }) } + fn is_responses_model(model_name: &str) -> bool { + let normalized = model_name.to_ascii_lowercase(); + normalized.contains("codex") + } + fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String { if is_embedding { "serving-endpoints/text-embedding-3-small/invocations".to_string() + } else if Self::is_responses_model(model_name) { + "serving-endpoints/responses".to_string() } else { format!("serving-endpoints/{}/invocations", model_name) } @@ -282,42 +299,79 @@ impl Provider for DatabricksProvider { messages: &[Message], tools: &[Tool], ) -> Result { - let mut payload = - create_request(model_config, system, messages, tools, &self.image_format)?; - payload - .as_object_mut() - .expect("payload should have model key") - .remove("model"); - - payload - .as_object_mut() - .unwrap() - .insert("stream".to_string(), Value::Bool(true)); - let path = self.get_endpoint_path(&model_config.model_name, false); - let mut log = RequestLog::start(model_config, &payload)?; - let response = self - .with_retry(|| async { - let resp = self - .api_client - .response_post(Some(session_id), &path, &payload) - .await?; - if !resp.status().is_success() { - let status = resp.status(); - let error_text = resp.text().await.unwrap_or_default(); - - // Parse as JSON if possible to pass to map_http_error_to_provider_error - let json_payload = serde_json::from_str::(&error_text).ok(); - return Err(map_http_error_to_provider_error(status, json_payload)); - } - Ok(resp) - }) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; - stream_openai_compat(response, log) + if Self::is_responses_model(&model_config.model_name) { + let mut payload = create_responses_request(model_config, system, messages, tools)?; + payload["stream"] = Value::Bool(true); + + let mut log = RequestLog::start(model_config, &payload)?; + + let response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + let resp = self + .api_client + .response_post(Some(session_id), &path, &payload_clone) + .await?; + handle_status_openai_compat(resp).await + }) + .await + .inspect_err(|e| { + let _ = log.error(e); + })?; + + let stream = response.bytes_stream().map_err(io::Error::other); + + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = responses_api_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?; + yield (message, usage); + } + })) + } else { + let mut payload = + create_request(model_config, system, messages, tools, &self.image_format)?; + payload + .as_object_mut() + .expect("payload should have model key") + .remove("model"); + + payload + .as_object_mut() + .unwrap() + .insert("stream".to_string(), Value::Bool(true)); + + let mut log = RequestLog::start(model_config, &payload)?; + let response = self + .with_retry(|| async { + let resp = self + .api_client + .response_post(Some(session_id), &path, &payload) + .await?; + if !resp.status().is_success() { + let status = resp.status(); + let error_text = resp.text().await.unwrap_or_default(); + + // Parse as JSON if possible to pass to map_http_error_to_provider_error + let json_payload = serde_json::from_str::(&error_text).ok(); + return Err(map_http_error_to_provider_error(status, json_payload)); + } + Ok(resp) + }) + .await + .inspect_err(|e| { + let _ = log.error(e); + })?; + + stream_openai_compat(response, log) + } } fn supports_embeddings(&self) -> bool {