2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

124 changes: 89 additions & 35 deletions crates/goose/src/providers/databricks.rs
@@ -1,18 +1,28 @@
use anyhow::Result;
use async_stream::try_stream;
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io;
use std::time::Duration;
use tokio::pin;
use tokio_util::codec::{FramedRead, LinesCodec};
use tokio_util::io::StreamReader;

use super::api_client::{ApiClient, AuthMethod, AuthProvider};
use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::formats::databricks::create_request;
use super::formats::openai_responses::{
create_responses_request, responses_api_to_streaming_message,
};
use super::oauth;
use super::openai_compatible::{
handle_response_openai_compat, map_http_error_to_provider_error, stream_openai_compat,
handle_response_openai_compat, handle_status_openai_compat, map_http_error_to_provider_error,
stream_openai_compat,
};
use super::retry::ProviderRetry;
use super::utils::{ImageFormat, RequestLog};
@@ -208,9 +218,16 @@ impl DatabricksProvider {
})
}

fn is_responses_model(model_name: &str) -> bool {
let normalized = model_name.to_ascii_lowercase();
normalized.contains("codex")
Review comment (P2): Route Responses API using capability, not name matching

Switching to the Responses endpoint based on normalized.contains("codex") causes incorrect routing for Databricks because this provider treats model_name as a serving endpoint name (see fetch_supported_models returning endpoint.name), not a guaranteed model-family identifier. Any endpoint whose name happens to include codex will now be forced to serving-endpoints/responses, and codex-backed endpoints without that substring will still use /invocations, so requests can be sent to the wrong API path despite valid endpoint configuration.
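
One possible direction, sketched below purely for illustration: let the endpoint configuration declare which API it speaks instead of inferring it from the endpoint name. The `uses_responses_api` field on the model/endpoint config is a hypothetical assumption for this sketch and does not exist in this PR.

```rust
// Sketch only: capability-based routing, assuming a hypothetical
// `uses_responses_api` flag carried on the model/endpoint config
// (not a field that exists in this PR).
fn get_endpoint_path(&self, model_config: &ModelConfig, is_embedding: bool) -> String {
    if is_embedding {
        "serving-endpoints/text-embedding-3-small/invocations".to_string()
    } else if model_config.uses_responses_api {
        // The endpoint declares which API it speaks, so endpoint names like
        // "my-codex-proxy" are not misrouted by substring matching.
        "serving-endpoints/responses".to_string()
    } else {
        format!("serving-endpoints/{}/invocations", model_config.model_name)
    }
}
```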


}

fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String {
if is_embedding {
"serving-endpoints/text-embedding-3-small/invocations".to_string()
} else if Self::is_responses_model(model_name) {
"serving-endpoints/responses".to_string()
} else {
format!("serving-endpoints/{}/invocations", model_name)
}
@@ -282,42 +299,79 @@ impl Provider for DatabricksProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let mut payload =
create_request(model_config, system, messages, tools, &self.image_format)?;
payload
.as_object_mut()
.expect("payload should have model key")
.remove("model");

payload
.as_object_mut()
.unwrap()
.insert("stream".to_string(), Value::Bool(true));

let path = self.get_endpoint_path(&model_config.model_name, false);
let mut log = RequestLog::start(model_config, &payload)?;
let response = self
.with_retry(|| async {
let resp = self
.api_client
.response_post(Some(session_id), &path, &payload)
.await?;
if !resp.status().is_success() {
let status = resp.status();
let error_text = resp.text().await.unwrap_or_default();

// Parse as JSON if possible to pass to map_http_error_to_provider_error
let json_payload = serde_json::from_str::<Value>(&error_text).ok();
return Err(map_http_error_to_provider_error(status, json_payload));
}
Ok(resp)
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;

stream_openai_compat(response, log)
if Self::is_responses_model(&model_config.model_name) {
let mut payload = create_responses_request(model_config, system, messages, tools)?;
payload["stream"] = Value::Bool(true);

let mut log = RequestLog::start(model_config, &payload)?;

let response = self
.with_retry(|| async {
// Clone per attempt: the retry closure may run more than once.
let payload_clone = payload.clone();
let resp = self
.api_client
.response_post(Some(session_id), &path, &payload_clone)
.await?;
handle_status_openai_compat(resp).await
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;

// Bridge the HTTP byte stream into line-delimited frames for the Responses SSE parser.
let stream = response.bytes_stream().map_err(io::Error::other);

Ok(Box::pin(try_stream! {
let stream_reader = StreamReader::new(stream);
let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);

let message_stream = responses_api_to_streaming_message(framed);
pin!(message_stream);
while let Some(message) = message_stream.next().await {
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
yield (message, usage);
}
}))
} else {
let mut payload =
create_request(model_config, system, messages, tools, &self.image_format)?;
payload
.as_object_mut()
.expect("payload should have model key")
.remove("model");

payload
.as_object_mut()
.unwrap()
.insert("stream".to_string(), Value::Bool(true));

let mut log = RequestLog::start(model_config, &payload)?;
let response = self
.with_retry(|| async {
let resp = self
.api_client
.response_post(Some(session_id), &path, &payload)
.await?;
if !resp.status().is_success() {
let status = resp.status();
let error_text = resp.text().await.unwrap_or_default();

// Parse as JSON if possible to pass to map_http_error_to_provider_error
let json_payload = serde_json::from_str::<Value>(&error_text).ok();
return Err(map_http_error_to_provider_error(status, json_payload));
}
Ok(resp)
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;

stream_openai_compat(response, log)
}
}

fn supports_embeddings(&self) -> bool {