Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/dc-mcp-server/src/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ pub trait Executable {
let response = client
.post(request.endpoint.as_str())
.headers(self.headers(&request.headers))
.header("Content-Type", "application/json")
.body(Value::Object(request_body).to_string())
.send()
.await
Expand Down
19 changes: 14 additions & 5 deletions crates/dc-mcp-server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::path::PathBuf;
use std::sync::Arc;

use apollo_mcp_registry::platform_api::operation_collections::collection_poller::CollectionSource;
use apollo_mcp_registry::uplink::persisted_queries::ManifestSource;
Expand All @@ -12,6 +13,7 @@ use dc_mcp_server::operations::OperationSource;
use dc_mcp_server::server::Server;
use dc_mcp_server::startup;
use runtime::IdOrDefault;
use tokio::sync::RwLock;
use tracing::{info, warn};

mod runtime;
Expand Down Expand Up @@ -41,7 +43,7 @@ async fn main() -> anyhow::Result<()> {
let config_path = args.config.clone();

// Read config for initial setup (telemetry)
let mut config: runtime::Config = match config_path.clone() {
let config: runtime::Config = match config_path.clone() {
Some(ref path) => runtime::read_config(path.clone())?,
None => runtime::read_config_from_env().unwrap_or_default(),
};
Expand All @@ -53,6 +55,9 @@ async fn main() -> anyhow::Result<()> {
env!("CARGO_PKG_VERSION")
);

// Create shared headers that can be updated by token refresh
let shared_headers = Arc::new(RwLock::new(config.headers.clone()));

// Check if token refresh is enabled
if startup::is_token_refresh_enabled() {
if let (Some(refresh_token), Some(refresh_url), Some(graphql_endpoint), Some(config_file)) = (
Expand All @@ -67,14 +72,14 @@ async fn main() -> anyhow::Result<()> {
refresh_token,
refresh_url,
graphql_endpoint,
Arc::clone(&shared_headers),
)
.await
{
warn!("Token refresh initialization failed: {}", e);
} else {
// Re-read config to get the refreshed token
info!("Re-reading config file to load refreshed token...");
config = runtime::read_config(config_file.clone())?;
// Token has been refreshed and shared_headers updated by initialize_with_token_refresh
info!("✅ Token refresh initialization complete");
}
} else {
warn!("Token refresh enabled but missing required environment variables");
Expand Down Expand Up @@ -141,13 +146,17 @@ async fn main() -> anyhow::Result<()> {

let transport = config.transport.clone();

// Read current headers from shared state
let current_headers = shared_headers.read().await.clone();

Ok(Server::builder()
.transport(config.transport)
.schema_source(schema_source)
.operation_source(operation_source)
.endpoint(config.endpoint.into_inner())
.maybe_explorer_graph_ref(explorer_graph_ref)
.headers(config.headers)
.headers(current_headers)
.maybe_shared_headers(Some(shared_headers))
.execute_introspection(config.introspection.execute.enabled)
.validate_introspection(config.introspection.validate.enabled)
.introspect_introspection(config.introspection.introspect.enabled)
Expand Down
5 changes: 5 additions & 0 deletions crates/dc-mcp-server/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;

use apollo_mcp_registry::uplink::schema::SchemaSource;
use bon::bon;
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
use schemars::JsonSchema;
use serde::Deserialize;
use tokio::sync::RwLock;
use url::Url;

use crate::auth;
Expand All @@ -26,6 +28,7 @@ pub struct Server {
operation_source: OperationSource,
endpoint: Url,
headers: HeaderMap,
shared_headers: Option<Arc<RwLock<HeaderMap>>>,
execute_introspection: bool,
validate_introspection: bool,
introspect_introspection: bool,
Expand Down Expand Up @@ -111,6 +114,7 @@ impl Server {
operation_source: OperationSource,
endpoint: Url,
headers: HeaderMap,
#[builder(into)] shared_headers: Option<Arc<RwLock<HeaderMap>>>,
execute_introspection: bool,
validate_introspection: bool,
introspect_introspection: bool,
Expand Down Expand Up @@ -139,6 +143,7 @@ impl Server {
operation_source,
endpoint,
headers,
shared_headers,
execute_introspection,
validate_introspection,
introspect_introspection,
Expand Down
5 changes: 5 additions & 0 deletions crates/dc-mcp-server/src/server/states.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::sync::Arc;

use apollo_compiler::{Schema, validation::Valid};
use apollo_federation::{ApiSchemaOptions, Supergraph};
use apollo_mcp_registry::uplink::schema::{SchemaState, event::Event as SchemaEvent};
use futures::{FutureExt as _, Stream, StreamExt as _, stream};
use reqwest::header::HeaderMap;
use tokio::sync::RwLock;
use url::Url;

use crate::{
Expand Down Expand Up @@ -34,6 +37,7 @@ struct Config {
transport: Transport,
endpoint: Url,
headers: HeaderMap,
shared_headers: Option<Arc<RwLock<HeaderMap>>>,
execute_introspection: bool,
validate_introspection: bool,
introspect_introspection: bool,
Expand Down Expand Up @@ -68,6 +72,7 @@ impl StateMachine {
transport: server.transport,
endpoint: server.endpoint,
headers: server.headers,
shared_headers: server.shared_headers,
execute_introspection: server.execute_introspection,
validate_introspection: server.validate_introspection,
introspect_introspection: server.introspect_introspection,
Expand Down
8 changes: 4 additions & 4 deletions crates/dc-mcp-server/src/server/states/running.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::{
pub(super) struct Running {
pub(super) schema: Arc<Mutex<Valid<Schema>>>,
pub(super) operations: Arc<Mutex<Vec<Operation>>>,
pub(super) headers: HeaderMap,
pub(super) headers: Arc<RwLock<HeaderMap>>,
pub(super) endpoint: Url,
pub(super) execute_tool: Option<Execute>,
pub(super) introspect_tool: Option<Introspect>,
Expand Down Expand Up @@ -235,7 +235,7 @@ impl ServerHandler for Running {
.await
}
EXECUTE_TOOL_NAME => {
let mut headers = self.headers.clone();
let mut headers = self.headers.read().await.clone();
if let Some(axum_parts) = context.extensions.get::<axum::http::request::Parts>() {
// Optionally extract the validated token and propagate it to upstream servers if present
if !self.disable_auth_token_passthrough
Expand Down Expand Up @@ -268,7 +268,7 @@ impl ServerHandler for Running {
.await
}
_ => {
let mut headers = self.headers.clone();
let mut headers = self.headers.read().await.clone();
if let Some(axum_parts) = context.extensions.get::<axum::http::request::Parts>() {
// Optionally extract the validated token and propagate it to upstream servers if present
if !self.disable_auth_token_passthrough
Expand Down Expand Up @@ -407,7 +407,7 @@ mod tests {
let running = Running {
schema: Arc::new(Mutex::new(schema)),
operations: Arc::new(Mutex::new(vec![])),
headers: HeaderMap::new(),
headers: Arc::new(RwLock::new(HeaderMap::new())),
endpoint: "http://localhost:4000".parse().unwrap(),
execute_tool: None,
introspect_tool: None,
Expand Down
6 changes: 5 additions & 1 deletion crates/dc-mcp-server/src/server/states/starting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ impl Starting {
let running = Running {
schema,
operations: Arc::new(Mutex::new(operations)),
headers: self.config.headers,
headers: self
.config
.shared_headers
.unwrap_or_else(|| Arc::new(RwLock::new(self.config.headers))),
endpoint: self.config.endpoint,
execute_tool,
introspect_tool,
Expand Down Expand Up @@ -355,6 +358,7 @@ mod tests {
mutation_mode: MutationMode::All,
execute_introspection: true,
headers: HeaderMap::new(),
shared_headers: None,
validate_introspection: true,
introspect_introspection: true,
search_introspection: true,
Expand Down
6 changes: 5 additions & 1 deletion crates/dc-mcp-server/src/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
use crate::config_manager::ConfigManager;
use crate::errors::McpError;
use crate::token_manager::TokenManager;
use reqwest::header::HeaderMap;
use rmcp::model::ErrorCode;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

/// Initialize the Apollo MCP Server with token refresh and environment setup
Expand All @@ -14,6 +16,7 @@ pub async fn initialize_with_token_refresh(
refresh_token: String,
refresh_url: String,
graphql_endpoint: String,
shared_headers: Arc<RwLock<HeaderMap>>,
) -> Result<(), McpError> {
info!("🎯 Apollo MCP Server initializing with token refresh...");

Expand All @@ -24,9 +27,10 @@ pub async fn initialize_with_token_refresh(
e
})?;

// Step 2: Initialize token manager with injected config manager
// Step 2: Initialize token manager with injected config manager and headers
let mut token_manager = TokenManager::new(refresh_token, refresh_url)?;
token_manager.set_config_manager(Arc::clone(&config_manager));
token_manager.set_headers(Arc::clone(&shared_headers));

// Step 3: Get fresh token (will automatically write to config)
let new_token = token_manager.get_valid_token().await.map_err(|e| {
Expand Down
23 changes: 23 additions & 0 deletions crates/dc-mcp-server/src/token_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
use crate::config_manager::ConfigManager;
use crate::errors::McpError;
use reqwest::Client;
use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
use rmcp::model::ErrorCode;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time::sleep;
use tracing::{debug, error, info, warn};

Expand All @@ -31,6 +33,7 @@ pub struct TokenManager {
token_expires_at: Option<Instant>,
client: Client,
config_manager: Option<Arc<ConfigManager>>,
headers: Option<Arc<RwLock<HeaderMap>>>,
}

impl TokenManager {
Expand Down Expand Up @@ -74,6 +77,7 @@ impl TokenManager {
token_expires_at: None,
client,
config_manager: None,
headers: None,
})
}

Expand All @@ -82,6 +86,11 @@ impl TokenManager {
self.config_manager = Some(config_manager);
}

/// Inject the shared headers for automatic token updates
pub fn set_headers(&mut self, headers: Arc<RwLock<HeaderMap>>) {
self.headers = Some(headers);
}

/// Get a valid access token, refreshing if necessary
pub async fn get_valid_token(&mut self) -> Result<String, McpError> {
// Check if we have a valid token
Expand Down Expand Up @@ -175,6 +184,19 @@ impl TokenManager {
}
}

// Update the shared headers if available
if let Some(headers) = &self.headers {
let mut headers_guard = headers.write().await;
if let Ok(header_value) =
HeaderValue::from_str(&format!("Bearer {}", token_response.access_token))
{
headers_guard.insert(AUTHORIZATION, header_value);
info!("✅ Refreshed token updated in shared headers");
} else {
warn!("Failed to create header value from token");
}
}

Ok(token_response.access_token)
}

Expand Down Expand Up @@ -266,6 +288,7 @@ impl Clone for TokenManager {
token_expires_at: self.token_expires_at,
client: self.client.clone(),
config_manager: self.config_manager.clone(),
headers: self.headers.clone(),
}
}
}
Expand Down