Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,18 @@ impl StreamableHttpClientTransport<reqwest::Client> {
reqwest::Client::default(),
StreamableHttpClientTransportConfig {
uri: uri.into(),
auth_header: None,
..Default::default()
},
)
}

/// Build this transport form a config
///
/// # Arguments
///
/// * `config` - The config to use with this transport
pub fn from_config(config: StreamableHttpClientTransportConfig) -> Self {
StreamableHttpClientTransport::with_client(reqwest::Client::default(), config)
}
}
30 changes: 27 additions & 3 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,12 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
let _ = responder.send(Ok(()));
let (message, session_id) = self
.client
.post_message(config.uri.clone(), initialize_request, None, None)
.post_message(
config.uri.clone(),
initialize_request,
None,
self.config.auth_header,
)
.await
.map_err(WorkerQuitReason::fatal_context("send initialize request"))?
.expect_initialized::<C::Error>()
Expand Down Expand Up @@ -339,7 +344,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
config.uri.clone(),
initialized_notification.message,
session_id.clone(),
None,
config.auth_header.clone(),
)
.await
.map_err(WorkerQuitReason::fatal_context(
Expand Down Expand Up @@ -426,7 +431,12 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
let WorkerSendRequest { message, responder } = send_request;
let response = self
.client
.post_message(config.uri.clone(), message, session_id.clone(), None)
.post_message(
config.uri.clone(),
message,
session_id.clone(),
config.auth_header.clone(),
)
.await;
let send_result = match response {
Err(e) => Err(e),
Expand Down Expand Up @@ -505,6 +515,8 @@ pub struct StreamableHttpClientTransportConfig {
pub channel_buffer_capacity: usize,
/// if true, the transport will not require a session to be established
pub allow_stateless: bool,
/// The value to send in the authorization header
pub auth_header: Option<String>,
}

impl StreamableHttpClientTransportConfig {
Expand All @@ -514,6 +526,17 @@ impl StreamableHttpClientTransportConfig {
..Default::default()
}
}

/// Set the authorization header to send with requests
///
/// # Arguments
///
/// * `value` - The value to set
pub fn auth_header<T: Into<String>>(mut self, value: T) -> Self {
// set our authorization header
self.auth_header = Some(value.into());
self
}
}

impl Default for StreamableHttpClientTransportConfig {
Expand All @@ -523,6 +546,7 @@ impl Default for StreamableHttpClientTransportConfig {
retry_config: Arc::new(ExponentialBackoff::default()),
channel_buffer_capacity: 16,
allow_stateless: true,
auth_header: None,
}
}
}