Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 73 additions & 3 deletions crates/goose/src/agents/extension_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl ExtensionManager {
);
}
let client = reqwest::Client::builder()
.default_headers(default_headers)
.default_headers(default_headers.clone())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not always send the bearer token if can acquire one along here? if we can acquire it, we need it anyway

.build()
.map_err(|_| {
ExtensionError::ConfigError("could not construct http client".to_string())
Expand All @@ -236,13 +236,83 @@ impl ExtensionManager {
..Default::default()
},
);
let client = McpClient::connect(
// Try initial connection
let client_result = McpClient::connect(
transport,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?;
.await;

// Check if it's a 401 error and try OAuth authentication
let client = match client_result {
Ok(client) => client,
Err(error) => {
let error_str = error.to_string();

// Check for OAuth-requiring endpoints that fail with connection errors
let is_oauth_endpoint = uri.contains("mcp.notion.com") ||
uri.contains("notion.com") ||
uri.contains("oauth");

let should_try_oauth = (error_str.contains("401") && error_str.contains("Unauthorized")) ||
(is_oauth_endpoint &&
(error_str.contains("connection closed") ||
error_str.contains("initialize response")));

if should_try_oauth {
// Create OAuth config from MCP endpoint
match mcp_client::ServiceConfig::from_mcp_endpoint(uri) {
Ok(oauth_config) => {
// Attempt OAuth authentication
match mcp_client::authenticate_service(oauth_config, uri).await {
Ok(access_token) => {
// Add Authorization header with the token
let mut oauth_headers = default_headers;
oauth_headers.insert(
HeaderName::from_static("authorization"),
format!("Bearer {}", access_token).parse().unwrap()
);

let oauth_client = reqwest::Client::builder()
.default_headers(oauth_headers)
.build()
.map_err(|_| {
ExtensionError::ConfigError("could not construct oauth http client".to_string())
})?;

let oauth_transport = StreamableHttpClientTransport::with_client(
oauth_client,
StreamableHttpClientTransportConfig {
uri: uri.clone().into(),
..Default::default()
},
);

// Retry connection with OAuth token
McpClient::connect(
oauth_transport,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?
}
Err(_oauth_error) => {
return Err(error.into());
}
}
}
Err(_config_error) => {
return Err(error.into());
}
}
} else {
return Err(error.into());
}
}
};
Box::new(client)
}
ExtensionConfig::Stdio {
Expand Down
Loading