Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 93 additions & 51 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use futures::{future, FutureExt};
use once_cell::sync::Lazy;
use rmcp::service::{ClientInitializeError, ServiceError};
use rmcp::transport::streamable_http_client::{
AuthRequiredError, StreamableHttpClientTransportConfig, StreamableHttpError,
StreamableHttpClientTransportConfig, StreamableHttpError,
};
use rmcp::transport::{
ConfigureCommandExt, DynamicTransportError, StreamableHttpClientTransport, TokioChildProcess,
Expand Down Expand Up @@ -296,25 +296,26 @@ async fn child_process_client(
}
}

fn extract_auth_error(
res: &Result<McpClient, ClientInitializeError>,
) -> Option<&AuthRequiredError> {
match res {
Ok(_) => None,
Err(err) => match err {
ClientInitializeError::TransportError {
error: DynamicTransportError { error, .. },
..
} => error
.downcast_ref::<StreamableHttpError<reqwest::Error>>()
.and_then(|auth_error| match auth_error {
StreamableHttpError::AuthRequired(auth_required_error) => {
Some(auth_required_error)
}
_ => None,
}),
_ => None,
},
/// Retry with OAuth for typed auth challenges and wrapped bare HTTP 401 responses.
fn should_attempt_oauth_fallback(res: &Result<McpClient, ClientInitializeError>) -> bool {
let Err(ClientInitializeError::TransportError {
error: DynamicTransportError { error, .. },
..
}) = res
else {
return false;
};

if let Some(http_err) = error.downcast_ref::<StreamableHttpError<reqwest::Error>>() {
match http_err {
StreamableHttpError::AuthRequired(_) => true,
StreamableHttpError::UnexpectedServerResponse(body) => body.starts_with("HTTP 401"),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Restrict OAuth fallback beyond generic HTTP 401

should_attempt_oauth_fallback now treats any UnexpectedServerResponse that starts with HTTP 401 as OAuth-capable, so ordinary unauthorized responses (for example, bad API-key/header credentials on non-OAuth servers) will trigger the OAuth flow path. In this code path, a failed oauth_flow replaces the original transport error with a generic setup error, which both misdirects users into an irrelevant auth flow and hides the actionable server response they need to fix credentials.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Restrict OAuth fallback for generic HTTP 401 failures

should_attempt_oauth_fallback treats any UnexpectedServerResponse whose body starts with HTTP 401 as OAuth-capable, so non-OAuth servers that use API keys/basic auth and return plain 401s will still enter oauth_flow. In that scenario users are pushed through an irrelevant OAuth attempt (browser flow/discovery) before getting the original transport error, which is a behavior regression for unauthorized-but-non-OAuth endpoints and can significantly confuse authentication troubleshooting.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor Author

@hydrosquall hydrosquall Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think after tracing through the rmcp source, this is not an issue. start_authorization calls discover_metadata().await? before any browser interaction. That method tries to fetch OAuth server metadata from the target server , so if the server doesn't support OAuth, both discovery attempts fail and the error propagates via ? out of oauth_flow before a browser ever opens. The call site then reaches Err(_) => Ok(Box::new(client_res?)) and returns the original 401 error.

In other words, for a non-OAuth server returning a bare 401, the worst case is two extra failing HTTP requests to the discovery endpoints. The browser will not open, and the caller still gets the original error back.

_ => false,
}
} else {
error
.to_string()
.contains("unexpected server response: HTTP 401")
}
}

Expand Down Expand Up @@ -453,37 +454,39 @@ async fn create_streamable_http_client(
)
.await;

if extract_auth_error(&client_res).is_some() {
let auth_manager = oauth_flow(&uri.to_string(), &name.to_string())
.await
.map_err(|_| ExtensionError::SetupError("auth error".to_string()))?;
let mut auth_headers = HeaderMap::new();
auth_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT);
let auth_http_client = reqwest::Client::builder()
.default_headers(auth_headers)
.build()
.map_err(|_| {
ExtensionError::ConfigError("could not construct http client".to_string())
})?;
let auth_client = AuthClient::new(auth_http_client, auth_manager);
let transport = StreamableHttpClientTransport::with_client(
auth_client,
StreamableHttpClientTransportConfig {
uri: uri.into(),
..Default::default()
},
);
Ok(Box::new(
McpClient::connect(
transport,
timeout_duration,
provider,
client_name,
capabilities,
roots_dir.to_path_buf(),
)
.await?,
))
if should_attempt_oauth_fallback(&client_res) {
match oauth_flow(&uri.to_string(), &name.to_string()).await {
Ok(auth_manager) => {
let mut auth_headers = HeaderMap::new();
auth_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT);
let auth_http_client = reqwest::Client::builder()
.default_headers(auth_headers)
.build()
.map_err(|_| {
ExtensionError::ConfigError("could not construct http client".to_string())
})?;
let auth_client = AuthClient::new(auth_http_client, auth_manager);
let transport = StreamableHttpClientTransport::with_client(
auth_client,
StreamableHttpClientTransportConfig {
uri: uri.into(),
..Default::default()
},
);
Ok(Box::new(
McpClient::connect(
transport,
timeout_duration,
provider,
client_name,
capabilities,
roots_dir.to_path_buf(),
)
.await?,
))
}
Err(_) => Ok(Box::new(client_res?)),
}
} else {
Ok(Box::new(client_res?))
}
Expand Down Expand Up @@ -2305,4 +2308,43 @@ mod tests {
"old extension must be preserved when replacement client creation fails"
);
}

fn transport_err(error: Box<dyn std::error::Error + Send + Sync>) -> ClientInitializeError {
ClientInitializeError::TransportError {
error: rmcp::transport::DynamicTransportError {
transport_name: "test".into(),
transport_type_id: std::any::TypeId::of::<()>(),
error,
},
context: "test context".into(),
}
}

fn streamable_err(
e: rmcp::transport::streamable_http_client::StreamableHttpError<reqwest::Error>,
) -> ClientInitializeError {
transport_err(Box::new(e))
}

#[test]
fn test_oauth_fallback_on_typed_auth_required() {
let err = streamable_err(
rmcp::transport::streamable_http_client::StreamableHttpError::AuthRequired(
rmcp::transport::streamable_http_client::AuthRequiredError {
www_authenticate_header: "Bearer realm=\"test\"".to_string(),
},
),
);
assert!(should_attempt_oauth_fallback(&Err(err)));
}

#[test]
fn test_oauth_fallback_on_unexpected_response_http_401_prefix() {
let err = streamable_err(
rmcp::transport::streamable_http_client::StreamableHttpError::UnexpectedServerResponse(
std::borrow::Cow::Borrowed("HTTP 401 Unauthorized"),
),
);
assert!(should_attempt_oauth_fallback(&Err(err)));
}
}
Loading