Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use tokio_stream::StreamExt;
use tokio_util::codec::{FramedRead, LinesCodec};

const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
const DEFAULT_REDIRECT_URL: &str = "http://localhost";
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
const DEFAULT_TIMEOUT_SECS: u64 = 600;

Expand Down
34 changes: 23 additions & 11 deletions crates/goose/src/providers/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl OAuthFlow {
})
}

fn get_authorization_url(&self) -> String {
fn get_authorization_url_with_redirect(&self, redirect_url: &str) -> String {
let challenge = {
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
Expand All @@ -214,7 +214,7 @@ impl OAuthFlow {
let params = [
("response_type", "code"),
("client_id", &self.client_id),
("redirect_uri", &self.redirect_url),
("redirect_uri", redirect_url),
("scope", &self.scopes.join(" ")),
("state", &self.state),
("code_challenge", &challenge),
Expand All @@ -228,11 +228,15 @@ impl OAuthFlow {
)
}

async fn exchange_code_for_token(&self, code: &str) -> Result<TokenData> {
async fn exchange_code_for_token_with_redirect(
&self,
code: &str,
redirect_url: &str,
) -> Result<TokenData> {
let params = [
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", &self.redirect_url),
("redirect_uri", redirect_url),
("code_verifier", &self.verifier),
("client_id", &self.client_id),
];
Expand Down Expand Up @@ -323,19 +327,26 @@ impl OAuthFlow {
);

// Start the server to accept the oauth code
let redirect_url = Url::parse(&self.redirect_url)?;
let port = redirect_url.port().unwrap_or(80);
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let redirect_url_parsed = Url::parse(&self.redirect_url)?;
let requested_port = redirect_url_parsed.port();

// If no port is specified (or port is explicitly 0), let the OS assign one
// Otherwise, use the requested port
let bind_port = requested_port.unwrap_or(0);
let addr = SocketAddr::from(([127, 0, 0, 1], bind_port));
let listener = tokio::net::TcpListener::bind(addr).await?;

let actual_port = listener.local_addr()?.port();

let server_handle = tokio::spawn(async move {
let server = axum::serve(listener, app);
server.await.unwrap();
});

let actual_redirect_url = format!("http://localhost:{}", actual_port);

// Open the browser which will redirect with the code to the server
let authorization_url = self.get_authorization_url();
let authorization_url = self.get_authorization_url_with_redirect(&actual_redirect_url);
if webbrowser::open(&authorization_url).is_err() {
println!(
"Please open this URL in your browser:\n{}",
Expand All @@ -354,8 +365,9 @@ impl OAuthFlow {
// Stop the server
server_handle.abort();

// Exchange the code for a token
self.exchange_code_for_token(&code).await
// Exchange the code for a token using the actual redirect URL
self.exchange_code_for_token_with_redirect(&code, &actual_redirect_url)
.await
}
}

Expand Down Expand Up @@ -531,7 +543,7 @@ mod tests {
let flow = OAuthFlow::new(
endpoints,
"test-client".to_string(),
"http://localhost:8020".to_string(),
"http://localhost".to_string(),
vec!["all-apis".to_string()],
);

Expand Down
Loading