diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 2c81294468c2..f70ee14d24ad 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -34,7 +34,7 @@ use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; -const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; +const DEFAULT_REDIRECT_URL: &str = "http://localhost"; const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"]; const DEFAULT_TIMEOUT_SECS: u64 = 600; diff --git a/crates/goose/src/providers/oauth.rs b/crates/goose/src/providers/oauth.rs index c0be3544bd12..c5d0271364f7 100644 --- a/crates/goose/src/providers/oauth.rs +++ b/crates/goose/src/providers/oauth.rs @@ -205,7 +205,7 @@ impl OAuthFlow { }) } - fn get_authorization_url(&self) -> String { + fn get_authorization_url_with_redirect(&self, redirect_url: &str) -> String { let challenge = { let digest = sha2::Sha256::digest(self.verifier.as_bytes()); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) @@ -214,7 +214,7 @@ impl OAuthFlow { let params = [ ("response_type", "code"), ("client_id", &self.client_id), - ("redirect_uri", &self.redirect_url), + ("redirect_uri", redirect_url), ("scope", &self.scopes.join(" ")), ("state", &self.state), ("code_challenge", &challenge), @@ -228,11 +228,15 @@ impl OAuthFlow { ) } - async fn exchange_code_for_token(&self, code: &str) -> Result { + async fn exchange_code_for_token_with_redirect( + &self, + code: &str, + redirect_url: &str, + ) -> Result { let params = [ ("grant_type", "authorization_code"), ("code", code), - ("redirect_uri", &self.redirect_url), + ("redirect_uri", redirect_url), ("code_verifier", &self.verifier), ("client_id", &self.client_id), ]; @@ -323,19 +327,26 @@ impl OAuthFlow { ); // Start the server to accept the oauth code - let redirect_url = Url::parse(&self.redirect_url)?; - let port = redirect_url.port().unwrap_or(80); - let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let redirect_url_parsed = Url::parse(&self.redirect_url)?; + let requested_port = redirect_url_parsed.port(); + // If no port is specified (or port is explicitly 0), let the OS assign one + // Otherwise, use the requested port + let bind_port = requested_port.unwrap_or(0); + let addr = SocketAddr::from(([127, 0, 0, 1], bind_port)); let listener = tokio::net::TcpListener::bind(addr).await?; + let actual_port = listener.local_addr()?.port(); + let server_handle = tokio::spawn(async move { let server = axum::serve(listener, app); server.await.unwrap(); }); + let actual_redirect_url = format!("http://localhost:{}", actual_port); + // Open the browser which will redirect with the code to the server - let authorization_url = self.get_authorization_url(); + let authorization_url = self.get_authorization_url_with_redirect(&actual_redirect_url); if webbrowser::open(&authorization_url).is_err() { println!( "Please open this URL in your browser:\n{}", @@ -354,8 +365,9 @@ impl OAuthFlow { // Stop the server server_handle.abort(); - // Exchange the code for a token - self.exchange_code_for_token(&code).await + // Exchange the code for a token using the actual redirect URL + self.exchange_code_for_token_with_redirect(&code, &actual_redirect_url) + .await } } @@ -531,7 +543,7 @@ mod tests { let flow = OAuthFlow::new( endpoints, "test-client".to_string(), - "http://localhost:8020".to_string(), + "http://localhost".to_string(), vec!["all-apis".to_string()], );