Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changesets/feat_david_support_traceparent_context.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Server adds support for incoming distributed trace context propagation - @david-castaneda PR #484

The MCP server now extracts W3C traceparent headers from incoming requests and uses this context for its own emitted traces, enabling handler spans to nest under parent traces for complete end-to-end observability.
1 change: 1 addition & 0 deletions crates/apollo-mcp-server/src/server/states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod operations_configured;
mod running;
mod schema_configured;
mod starting;
mod telemetry;

use configuring::Configuring;
use operations_configured::OperationsConfigured;
Expand Down
7 changes: 4 additions & 3 deletions crates/apollo-mcp-server/src/server/states/running.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::apps::find_and_execute_app;
use crate::generated::telemetry::{TelemetryAttribute, TelemetryMetric};
use crate::meter;
use crate::operations::{execute_operation, find_and_execute_operation};
use crate::server::states::telemetry::get_parent_span;
use crate::{
apps::AppResource,
custom_scalar_map::CustomScalarMap,
Expand Down Expand Up @@ -337,7 +338,7 @@ impl Running {
}

impl ServerHandler for Running {
#[tracing::instrument(skip_all, fields(apollo.mcp.client_name = request.client_info.name, apollo.mcp.client_version = request.client_info.version))]
#[tracing::instrument(skip_all, parent = get_parent_span(&context), fields(apollo.mcp.client_name = request.client_info.name, apollo.mcp.client_version = request.client_info.version))]
async fn initialize(
&self,
request: InitializeRequestParam,
Expand All @@ -364,7 +365,7 @@ impl ServerHandler for Running {
Ok(self.get_info())
}

#[tracing::instrument(skip_all, fields(apollo.mcp.tool_name = request.name.as_ref(), apollo.mcp.request_id = %context.id.clone()))]
#[tracing::instrument(skip_all, parent = get_parent_span(&context), fields(apollo.mcp.tool_name = request.name.as_ref(), apollo.mcp.request_id = %context.id.clone()))]
async fn call_tool(
&self,
request: CallToolRequestParam,
Expand Down Expand Up @@ -489,7 +490,7 @@ impl ServerHandler for Running {
result
}

#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all, parent = get_parent_span(&context))]
async fn list_tools(
&self,
_request: Option<PaginatedRequestParam>,
Expand Down
38 changes: 4 additions & 34 deletions crates/apollo-mcp-server/src/server/states/starting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{net::SocketAddr, sync::Arc};
use apollo_compiler::{Name, Schema, ast::OperationType, validation::Valid};
use axum::{Router, extract::Query, http::StatusCode, response::Json, routing::get};
use axum_otel_metrics::HttpMetricsLayerBuilder;
use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer};
use axum_tracing_opentelemetry::middleware::OtelInResponseLayer;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
use rmcp::{
Expand All @@ -14,9 +14,9 @@ use rmcp::{
use serde_json::json;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
use tracing::{Instrument as _, debug, error, info, trace};

use crate::server::states::telemetry::otel_context_middleware;
use crate::{
errors::ServerError,
explorer::Explorer,
Expand Down Expand Up @@ -228,38 +228,8 @@ impl Starting {
.layer(HttpMetricsLayerBuilder::new().build())
// include trace context as header into the response
.layer(OtelInResponseLayer)
//start OpenTelemetry trace on incoming request
.layer(OtelAxumLayer::default())
// Add tower-http tracing layer for additional HTTP-level tracing
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &axum::http::Request<_>| {
tracing::info_span!(
"mcp_server",
method = %request.method(),
uri = %request.uri(),
session_id = tracing::field::Empty,
status_code = tracing::field::Empty,
)
})
.on_response(
|response: &axum::http::Response<_>,
_latency: std::time::Duration,
span: &tracing::Span| {
span.record(
"status_code",
tracing::field::display(response.status()),
);
if let Some(session_id) = response
.headers()
.get("mcp-session-id")
.and_then(|v| v.to_str().ok())
{
span.record("session_id", tracing::field::display(session_id));
}
},
),
);
// start OpenTelemetry trace on incoming request
.layer(axum::middleware::from_fn(otel_context_middleware));

// Add health check endpoint if configured
if let Some(health_check) = health_check.filter(|h| h.config().enabled) {
Expand Down
172 changes: 172 additions & 0 deletions crates/apollo-mcp-server/src/server/states/telemetry.rs
Comment thread
david-castaneda marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use opentelemetry::global;
use opentelemetry::propagation::Extractor;
use rmcp::RoleServer;
use rmcp::service::RequestContext;
use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;

// Custom extractor for axum headers
struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);

// Implement the Extractor trait for HeaderExtractor
impl<'a> Extractor for HeaderExtractor<'a> {
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|v| v.to_str().ok())
}

fn keys(&self) -> Vec<&str> {
self.0.keys().map(|k| k.as_str()).collect()
}
}

// Middleware that extracts and stores OpenTelemetry context in request extensions
pub async fn otel_context_middleware(mut request: Request, next: Next) -> Response {
let parent_cx = global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(request.headers()))
});

request.extensions_mut().insert(parent_cx.clone()); // Store the OtelContext directly in extensions

let span = tracing::info_span!(
"mcp_server",
method = %request.method(),
uri = %request.uri(),
session_id = tracing::field::Empty,
status_code = tracing::field::Empty,
);
span.set_parent(parent_cx);

request.extensions_mut().insert(span.clone()); // Store the span in request extensions

let response = next.run(request).instrument(span.clone()).await;

span.record("status_code", tracing::field::display(response.status()));

if let Some(session_id) = response
.headers()
.get("mcp-session-id")
.and_then(|v| v.to_str().ok())
{
span.record("session_id", tracing::field::display(session_id));
}

response
}

// Helper function to retrieve the parent span from the request context
pub fn get_parent_span(context: &RequestContext<RoleServer>) -> tracing::Span {
context
.extensions
.get::<axum::http::request::Parts>()
.and_then(|parts| parts.extensions.get::<tracing::Span>())
.cloned()
.unwrap_or_else(tracing::Span::none)
}

#[cfg(test)]
mod tests {
use super::*;
use axum::{Router, body::Body, http::Request, routing::get};
use http::HeaderName;
use opentelemetry::Context as OtelContext;
use opentelemetry::trace::TraceContextExt;
use tower::ServiceExt;

#[tokio::test()]
async fn test_middleware_stores_span_context_and_handler_works() {
opentelemetry::global::set_text_map_propagator(
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
);

async fn test_handler(req: Request<Body>) -> &'static str {
let (parts, _body) = req.into_parts();

// Get OtelContext from extensions
let otel_ctx = parts
.extensions
.get::<OtelContext>()
.expect("OtelContext should be in extensions");

let trace_id = format!("{:032x}", otel_ctx.span().span_context().trace_id());
assert_eq!(trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");

// Verify span is also stored
let span = parts.extensions.get::<tracing::Span>();
assert!(span.is_some());

"ok"
}

let app = Router::new()
.route("/test", get(test_handler))
.layer(axum::middleware::from_fn(otel_context_middleware));

let traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
let request = Request::builder()
.uri("/test")
.header("traceparent", traceparent)
.body(Body::empty())
.unwrap();

let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), 200);
}

#[tokio::test]
async fn test_middleware_works_without_traceparent() {
opentelemetry::global::set_text_map_propagator(
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
);

let app = Router::new()
.route("/test", get(|| async { "ok" }))
.layer(axum::middleware::from_fn(otel_context_middleware));

let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

let response = app.oneshot(request).await.unwrap();

assert_eq!(response.status(), 200);
}

#[test]
fn test_header_extractor_gets_values() {
let mut headers = axum::http::HeaderMap::new();
headers.insert("traceparent", "test-value".parse().unwrap());
headers.insert("x-custom", "custom-value".parse().unwrap());

let extractor = HeaderExtractor(&headers);

assert_eq!(extractor.get("traceparent"), Some("test-value"));
assert_eq!(extractor.get("x-custom"), Some("custom-value"));
assert_eq!(extractor.get("missing"), None);
}

#[test]
fn test_header_extractor_keys() {
let mut headers = axum::http::HeaderMap::new();
headers.insert("traceparent", "test-value".parse().unwrap());
headers.insert("x-custom", "custom-value".parse().unwrap());

let extractor = HeaderExtractor(&headers);

let mut keys = extractor
.keys()
.into_iter()
.map(|k| HeaderName::from_bytes(k.as_bytes()).unwrap())
.collect::<Vec<_>>();

let mut expected = vec![
HeaderName::from_static("traceparent"),
HeaderName::from_static("x-custom"),
];

keys.sort_by(|a, b| a.as_str().cmp(b.as_str()));
expected.sort_by(|a, b| a.as_str().cmp(b.as_str()));

assert_eq!(keys, expected);
}
}