Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
async-trait = "0.1.89"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure how big a concern it is to have this, but if you want to avoid adding this dependency you can have the trait return Futures

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I often use this dependency and I am planning to introduce it in the future to make the code more concise. Do you want to change other implementations to this dependency implementation as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can use it more - it seems nice

serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
thiserror = "2"
Expand Down
141 changes: 118 additions & 23 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{collections::HashMap, sync::Arc, time::Duration};

use async_trait::async_trait;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
Expand All @@ -17,6 +18,62 @@ use tracing::{debug, error, warn};

const DEFAULT_EXCHANGE_URL: &str = "http://localhost";

/// Stored credentials for OAuth2 authorization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
pub client_id: String,
pub token_response: Option<OAuthTokenResponse>,
}

/// Trait for storing and retrieving OAuth2 credentials
///
/// Implementations of this trait can provide custom storage backends
/// for OAuth2 credentials, such as file-based storage, keychain integration,
/// or database storage.
#[async_trait]
pub trait CredentialStore: Send + Sync {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError>;

async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError>;

async fn clear(&self) -> Result<(), AuthError>;
}

/// In-memory credential store (default implementation)
///
/// This store keeps credentials in memory only and does not persist them
/// between application restarts. This is the default behavior when no
/// custom credential store is provided.
#[derive(Debug, Default, Clone)]
pub struct InMemoryCredentialStore {
credentials: Arc<RwLock<Option<StoredCredentials>>>,
}

impl InMemoryCredentialStore {
pub fn new() -> Self {
Self {
credentials: Arc::new(RwLock::new(None)),
}
}
}

#[async_trait::async_trait]
impl CredentialStore for InMemoryCredentialStore {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
Ok(self.credentials.read().await.clone())
}

async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
*self.credentials.write().await = Some(credentials);
Ok(())
}

async fn clear(&self) -> Result<(), AuthError> {
*self.credentials.write().await = None;
Ok(())
}
}

/// sse client with oauth2 authorization
#[derive(Clone)]
pub struct AuthClient<C> {
Expand Down Expand Up @@ -151,7 +208,7 @@ pub struct AuthorizationManager {
http_client: HttpClient,
metadata: Option<AuthorizationMetadata>,
oauth_client: Option<OAuthClient>,
credentials: RwLock<Option<OAuthTokenResponse>>,
credential_store: Arc<dyn CredentialStore>,
state: RwLock<Option<AuthorizationState>>,
base_url: Url,
}
Expand Down Expand Up @@ -222,14 +279,42 @@ impl AuthorizationManager {
http_client,
metadata: None,
oauth_client: None,
credentials: RwLock::new(None),
credential_store: Arc::new(InMemoryCredentialStore::new()),
state: RwLock::new(None),
base_url,
};

Ok(manager)
}

/// Set a custom credential store
///
/// This allows you to provide your own implementation of credential storage,
/// such as file-based storage, keychain integration, or database storage.
/// This should be called before any operations that need credentials.
pub fn set_credential_store<S: CredentialStore + 'static>(&mut self, store: S) {
self.credential_store = Arc::new(store);
}

/// Initialize from stored credentials if available
///
/// This will load credentials from the credential store and configure
/// the client if credentials are found.
pub async fn initialize_from_store(&mut self) -> Result<bool, AuthError> {
if let Some(stored) = self.credential_store.load().await? {
if stored.token_response.is_some() {
if self.metadata.is_none() {
let metadata = self.discover_metadata().await?;
self.metadata = Some(metadata);
}

self.configure_client_id(&stored.client_id)?;
return Ok(true);
}
}
Ok(false)
}

pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> {
self.http_client = http_client;
Ok(())
Expand All @@ -252,13 +337,16 @@ impl AuthorizationManager {

/// get client id and credentials
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
let credentials = self.credentials.read().await;
let client_id = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?
.client_id();
Ok((client_id.to_string(), credentials.clone()))

let stored = self.credential_store.load().await?;
let token_response = stored.and_then(|s| s.token_response);

Ok((client_id.to_string(), token_response))
}

/// configure oauth2 client with client credentials
Expand Down Expand Up @@ -309,7 +397,6 @@ impl AuthorizationManager {
));
};

// prepare registration request
let registration_request = ClientRegistrationRequest {
client_name: name.to_string(),
redirect_uris: vec![redirect_uri.to_string()],
Expand Down Expand Up @@ -479,23 +566,28 @@ impl AuthorizationManager {
};

debug!("exchange token result: {:?}", token_result);
// store credentials
*self.credentials.write().await = Some(token_result.clone());

// Store credentials in the credential store
let client_id = oauth_client.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
};
self.credential_store.save(stored).await?;

Ok(token_result)
}

/// get access token, if expired, refresh it automatically
pub async fn get_access_token(&self) -> Result<String, AuthError> {
let credentials = self.credentials.read().await;
// Load credentials from store
let stored = self.credential_store.load().await?;
let credentials = stored.and_then(|s| s.token_response);

if let Some(creds) = credentials.as_ref() {
// check if the token is expire
let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0));
if expires_in <= Duration::from_secs(0) {
tracing::info!("Access token expired, refreshing.");
// token expired, try to refresh , release the lock
drop(credentials);

let new_creds = self.refresh_token().await?;
tracing::info!("Refreshed access token.");
Expand All @@ -517,26 +609,28 @@ impl AuthorizationManager {
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;

let current_credentials = self
.credentials
.read()
.await
.clone()
let stored = self.credential_store.load().await?;
let current_credentials = stored
.and_then(|s| s.token_response)
.ok_or_else(|| AuthError::AuthorizationRequired)?;

let refresh_token = current_credentials.refresh_token().ok_or_else(|| {
AuthError::TokenRefreshFailed("No refresh token available".to_string())
})?;
debug!("refresh token: {:?}", refresh_token);
// refresh token

let token_result = oauth_client
.exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string()))
.request_async(&self.http_client)
.await
.map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?;

// store new credentials
*self.credentials.write().await = Some(token_result.clone());
let client_id = oauth_client.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
};
self.credential_store.save(stored).await?;

Ok(token_result)
}
Expand Down Expand Up @@ -1003,14 +1097,15 @@ impl OAuthState {
AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?,
);

// write credentials
*manager.credentials.write().await = Some(credentials);
let stored = StoredCredentials {
client_id: client_id.to_string(),
token_response: Some(credentials),
};
manager.credential_store.save(stored).await?;

// discover metadata
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);

// set client id and secret
manager.configure_client_id(client_id)?;

*self = OAuthState::Authorized(manager);
Expand Down