diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 6dd055de5277..bee360181c08 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -6,7 +6,7 @@ use futures::{future, FutureExt}; use once_cell::sync::Lazy; use rmcp::service::{ClientInitializeError, ServiceError}; use rmcp::transport::streamable_http_client::{ - AuthRequiredError, StreamableHttpClientTransportConfig, StreamableHttpError, + StreamableHttpClientTransportConfig, StreamableHttpError, }; use rmcp::transport::{ ConfigureCommandExt, DynamicTransportError, StreamableHttpClientTransport, TokioChildProcess, @@ -296,25 +296,26 @@ async fn child_process_client( } } -fn extract_auth_error( - res: &Result, -) -> Option<&AuthRequiredError> { - match res { - Ok(_) => None, - Err(err) => match err { - ClientInitializeError::TransportError { - error: DynamicTransportError { error, .. }, - .. - } => error - .downcast_ref::>() - .and_then(|auth_error| match auth_error { - StreamableHttpError::AuthRequired(auth_required_error) => { - Some(auth_required_error) - } - _ => None, - }), - _ => None, - }, +/// Retry with OAuth for typed auth challenges and wrapped bare HTTP 401 responses. +fn should_attempt_oauth_fallback(res: &Result) -> bool { + let Err(ClientInitializeError::TransportError { + error: DynamicTransportError { error, .. }, + .. + }) = res + else { + return false; + }; + + if let Some(http_err) = error.downcast_ref::>() { + match http_err { + StreamableHttpError::AuthRequired(_) => true, + StreamableHttpError::UnexpectedServerResponse(body) => body.starts_with("HTTP 401"), + _ => false, + } + } else { + error + .to_string() + .contains("unexpected server response: HTTP 401") } } @@ -453,37 +454,39 @@ async fn create_streamable_http_client( ) .await; - if extract_auth_error(&client_res).is_some() { - let auth_manager = oauth_flow(&uri.to_string(), &name.to_string()) - .await - .map_err(|_| ExtensionError::SetupError("auth error".to_string()))?; - let mut auth_headers = HeaderMap::new(); - auth_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT); - let auth_http_client = reqwest::Client::builder() - .default_headers(auth_headers) - .build() - .map_err(|_| { - ExtensionError::ConfigError("could not construct http client".to_string()) - })?; - let auth_client = AuthClient::new(auth_http_client, auth_manager); - let transport = StreamableHttpClientTransport::with_client( - auth_client, - StreamableHttpClientTransportConfig { - uri: uri.into(), - ..Default::default() - }, - ); - Ok(Box::new( - McpClient::connect( - transport, - timeout_duration, - provider, - client_name, - capabilities, - roots_dir.to_path_buf(), - ) - .await?, - )) + if should_attempt_oauth_fallback(&client_res) { + match oauth_flow(&uri.to_string(), &name.to_string()).await { + Ok(auth_manager) => { + let mut auth_headers = HeaderMap::new(); + auth_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT); + let auth_http_client = reqwest::Client::builder() + .default_headers(auth_headers) + .build() + .map_err(|_| { + ExtensionError::ConfigError("could not construct http client".to_string()) + })?; + let auth_client = AuthClient::new(auth_http_client, auth_manager); + let transport = StreamableHttpClientTransport::with_client( + auth_client, + StreamableHttpClientTransportConfig { + uri: uri.into(), + ..Default::default() + }, + ); + Ok(Box::new( + McpClient::connect( + transport, + timeout_duration, + provider, + client_name, + capabilities, + roots_dir.to_path_buf(), + ) + .await?, + )) + } + Err(_) => Ok(Box::new(client_res?)), + } } else { Ok(Box::new(client_res?)) } @@ -2305,4 +2308,43 @@ mod tests { "old extension must be preserved when replacement client creation fails" ); } + + fn transport_err(error: Box) -> ClientInitializeError { + ClientInitializeError::TransportError { + error: rmcp::transport::DynamicTransportError { + transport_name: "test".into(), + transport_type_id: std::any::TypeId::of::<()>(), + error, + }, + context: "test context".into(), + } + } + + fn streamable_err( + e: rmcp::transport::streamable_http_client::StreamableHttpError, + ) -> ClientInitializeError { + transport_err(Box::new(e)) + } + + #[test] + fn test_oauth_fallback_on_typed_auth_required() { + let err = streamable_err( + rmcp::transport::streamable_http_client::StreamableHttpError::AuthRequired( + rmcp::transport::streamable_http_client::AuthRequiredError { + www_authenticate_header: "Bearer realm=\"test\"".to_string(), + }, + ), + ); + assert!(should_attempt_oauth_fallback(&Err(err))); + } + + #[test] + fn test_oauth_fallback_on_unexpected_response_http_401_prefix() { + let err = streamable_err( + rmcp::transport::streamable_http_client::StreamableHttpError::UnexpectedServerResponse( + std::borrow::Cow::Borrowed("HTTP 401 Unauthorized"), + ), + ); + assert!(should_attempt_oauth_fallback(&Err(err))); + } }