From a34436e4fad76e71492e9a1f85ffa0adc021eda2 Mon Sep 17 00:00:00 2001 From: jh-nv Date: Thu, 8 Jan 2026 19:12:24 -0500 Subject: [PATCH] fix: distributed tracing propagation for TCP transport (#5283) Co-authored-by: Ishan Dhanani --- lib/runtime/src/logging.rs | 142 +++++++++++++++++- lib/runtime/src/pipeline/network/codec.rs | 129 +++++++++++++++- .../src/pipeline/network/egress/tcp_client.rs | 36 ++++- .../network/ingress/shared_tcp_endpoint.rs | 32 +++- 4 files changed, 317 insertions(+), 22 deletions(-) diff --git a/lib/runtime/src/logging.rs b/lib/runtime/src/logging.rs index 7d02fb65849..52c42fcb278 100644 --- a/lib/runtime/src/logging.rs +++ b/lib/runtime/src/logging.rs @@ -278,6 +278,8 @@ pub fn make_request_span(req: &Request) -> Span { let version = format!("{:?}", req.version()); let trace_parent = TraceParent::from_headers(req.headers()); + let otel_context = extract_otel_context_from_http_headers(req.headers()); + let span = tracing::info_span!( "http-request", method = %method, @@ -286,12 +288,52 @@ pub fn make_request_span(req: &Request) -> Span { trace_id = trace_parent.trace_id, parent_id = trace_parent.parent_id, x_request_id = trace_parent.x_request_id, - x_dynamo_request_id = trace_parent.x_dynamo_request_id, + x_dynamo_request_id = trace_parent.x_dynamo_request_id, ); + if let Some(context) = otel_context { + let _ = span.set_parent(context); + } + span } +/// Extract OpenTelemetry context from HTTP headers for distributed tracing +fn extract_otel_context_from_http_headers( + headers: &http::HeaderMap, +) -> Option { + let traceparent_value = headers.get("traceparent")?.to_str().ok()?; + + struct HttpHeaderExtractor<'a>(&'a http::HeaderMap); + + impl<'a> Extractor for HttpHeaderExtractor<'a> { + fn get(&self, key: &str) -> Option<&str> { + self.0.get(key).and_then(|v| v.to_str().ok()) + } + + fn keys(&self) -> Vec<&str> { + vec!["traceparent", "tracestate"] + .into_iter() + .filter(|&key| self.0.get(key).is_some()) + .collect() + } + } + + // Early return if traceparent is empty + if traceparent_value.is_empty() { + return None; + } + + let extractor = HttpHeaderExtractor(headers); + let otel_context = TRACE_PROPAGATOR.extract(&extractor); + + if otel_context.span().span_context().is_valid() { + Some(otel_context) + } else { + None + } +} + /// Create a handle_payload span from NATS headers with component context pub fn make_handle_payload_span( headers: &async_nats::HeaderMap, @@ -335,6 +377,93 @@ pub fn make_handle_payload_span( } } +/// Create a handle_payload span from TCP/HashMap headers with component context +pub fn make_handle_payload_span_from_tcp_headers( + headers: &std::collections::HashMap, + component: &str, + endpoint: &str, + namespace: &str, + instance_id: u64, +) -> Span { + let (otel_context, trace_id, parent_span_id) = extract_otel_context_from_tcp_headers(headers); + let x_request_id = headers.get("x-request-id").cloned(); + let x_dynamo_request_id = headers.get("x-dynamo-request-id").cloned(); + let tracestate = headers.get("tracestate").cloned(); + + if let (Some(trace_id), Some(parent_id)) = (trace_id.as_ref(), parent_span_id.as_ref()) { + let span = tracing::info_span!( + "handle_payload", + trace_id = trace_id.as_str(), + parent_id = parent_id.as_str(), + x_request_id = x_request_id, + x_dynamo_request_id = x_dynamo_request_id, + tracestate = tracestate, + component = component, + endpoint = endpoint, + namespace = namespace, + instance_id = instance_id, + ); + + if let Some(context) = otel_context { + let _ = span.set_parent(context); + } + span + } else { + tracing::info_span!( + "handle_payload", + x_request_id = x_request_id, + x_dynamo_request_id = x_dynamo_request_id, + tracestate = tracestate, + component = component, + endpoint = endpoint, + namespace = namespace, + instance_id = instance_id, + ) + } +} + +/// Extract OpenTelemetry trace context from TCP/HashMap headers for distributed tracing +fn extract_otel_context_from_tcp_headers( + headers: &std::collections::HashMap, +) -> ( + Option, + Option, + Option, +) { + let traceparent_value = match headers.get("traceparent") { + Some(value) => value.as_str(), + None => return (None, None, None), + }; + + let (trace_id, parent_span_id) = parse_traceparent(traceparent_value); + + struct TcpHeaderExtractor<'a>(&'a std::collections::HashMap); + + impl<'a> Extractor for TcpHeaderExtractor<'a> { + fn get(&self, key: &str) -> Option<&str> { + self.0.get(key).map(|s| s.as_str()) + } + + fn keys(&self) -> Vec<&str> { + vec!["traceparent", "tracestate"] + .into_iter() + .filter(|&key| self.0.get(key).is_some()) + .collect() + } + } + + let extractor = TcpHeaderExtractor(headers); + let otel_context = TRACE_PROPAGATOR.extract(&extractor); + + let context_with_trace = if otel_context.span().span_context().is_valid() { + Some(otel_context) + } else { + None + }; + + (context_with_trace, trace_id, parent_span_id) +} + /// Extract OpenTelemetry trace context from NATS headers for distributed tracing pub fn extract_otel_context_from_nats_headers( headers: &async_nats::HeaderMap, @@ -366,8 +495,7 @@ pub fn extract_otel_context_from_nats_headers( } let extractor = NatsHeaderExtractor(headers); - let propagator = opentelemetry_sdk::propagation::TraceContextPropagator::new(); - let otel_context = propagator.extract(&extractor); + let otel_context = TRACE_PROPAGATOR.extract(&extractor); let context_with_trace = if otel_context.span().span_context().is_valid() { Some(otel_context) @@ -394,8 +522,7 @@ pub fn inject_otel_context_into_nats_headers( } let mut injector = NatsHeaderInjector(headers); - let propagator = opentelemetry_sdk::propagation::TraceContextPropagator::new(); - propagator.inject_context(&otel_context, &mut injector); + TRACE_PROPAGATOR.inject_context(&otel_context, &mut injector); } /// Inject trace context from current span into NATS headers @@ -948,6 +1075,11 @@ impl CustomJsonFormatter { use once_cell::sync::Lazy; use regex::Regex; + +/// Static W3C Trace Context propagator instance to avoid repeated allocations +static TRACE_PROPAGATOR: Lazy = + Lazy::new(opentelemetry_sdk::propagation::TraceContextPropagator::new); + fn parse_tracing_duration(s: &str) -> Option { static RE: Lazy = Lazy::new(|| Regex::new(r#"^["']?\s*([0-9.]+)\s*(µs|us|ns|ms|s)\s*["']?$"#).unwrap()); diff --git a/lib/runtime/src/pipeline/network/codec.rs b/lib/runtime/src/pipeline/network/codec.rs index 0452342d9d8..f5ece135944 100644 --- a/lib/runtime/src/pipeline/network/codec.rs +++ b/lib/runtime/src/pipeline/network/codec.rs @@ -18,16 +18,19 @@ mod two_part; pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType}; -/// TCP request plane protocol message with endpoint routing +/// TCP request plane protocol message with endpoint routing and trace headers /// /// Wire format: /// - endpoint_path_len: u16 (big-endian) /// - endpoint_path: UTF-8 string +/// - headers_len: u16 (big-endian) +/// - headers: JSON-encoded HashMap /// - payload_len: u32 (big-endian) /// - payload: bytes #[derive(Debug, Clone, PartialEq, Eq)] pub struct TcpRequestMessage { pub endpoint_path: String, + pub headers: std::collections::HashMap, pub payload: Bytes, } @@ -35,6 +38,19 @@ impl TcpRequestMessage { pub fn new(endpoint_path: String, payload: Bytes) -> Self { Self { endpoint_path, + headers: std::collections::HashMap::new(), + payload, + } + } + + pub fn with_headers( + endpoint_path: String, + headers: std::collections::HashMap, + payload: Bytes, + ) -> Self { + Self { + endpoint_path, + headers, payload, } } @@ -51,6 +67,22 @@ impl TcpRequestMessage { )); } + // Encode headers as JSON + let headers_json = serde_json::to_vec(&self.headers).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Failed to encode headers: {}", e), + ) + })?; + let headers_len = headers_json.len(); + + if headers_len > u16::MAX as usize { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Headers too large: {} bytes", headers_len), + )); + } + if self.payload.len() > u32::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, @@ -59,7 +91,8 @@ impl TcpRequestMessage { } // Use BytesMut for efficient buffer building - let mut buf = BytesMut::with_capacity(2 + endpoint_len + 4 + self.payload.len()); + let mut buf = + BytesMut::with_capacity(2 + endpoint_len + 2 + headers_len + 4 + self.payload.len()); // Write endpoint path length (2 bytes) buf.put_u16(endpoint_len as u16); @@ -67,6 +100,12 @@ impl TcpRequestMessage { // Write endpoint path buf.put_slice(endpoint_bytes); + // Write headers length (2 bytes) + buf.put_u16(headers_len as u16); + + // Write headers + buf.put_slice(&headers_json); + // Write payload length (4 bytes) buf.put_u32(self.payload.len() as u32); @@ -102,11 +141,39 @@ impl TcpRequestMessage { .map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, - format!("Invalid UTF-8: {}", e), + format!("Invalid UTF-8 in endpoint path: {}", e), ) })?; offset += endpoint_len; + if bytes.len() < offset + 2 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Not enough bytes for headers length", + )); + } + + // Read headers length (2 bytes) + let headers_len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]) as usize; + offset += 2; + + if bytes.len() < offset + headers_len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Not enough bytes for headers", + )); + } + + // Read and parse headers + let headers: std::collections::HashMap = + serde_json::from_slice(&bytes[offset..offset + headers_len]).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid JSON in headers: {}", e), + ) + })?; + offset += headers_len; + if bytes.len() < offset + 4 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -139,6 +206,7 @@ impl TcpRequestMessage { Ok(Self { endpoint_path, + headers, payload, }) } @@ -169,14 +237,25 @@ impl Decoder for TcpRequestCodec { // Peek at endpoint path length without consuming let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize; - let header_size = 2 + endpoint_len + 4; // path_len + path + payload_len + // Need path + headers_len + if src.len() < 2 + endpoint_len + 2 { + return Ok(None); + } + + // Peek at headers length + let headers_len_offset = 2 + endpoint_len; + let headers_len = + u16::from_be_bytes([src[headers_len_offset], src[headers_len_offset + 1]]) as usize; + + // Need path + headers + payload_len + let header_size = 2 + endpoint_len + 2 + headers_len + 4; if src.len() < header_size { return Ok(None); } // Peek at payload length - let payload_len_offset = 2 + endpoint_len; + let payload_len_offset = 2 + endpoint_len + 2 + headers_len; let payload_len = u32::from_be_bytes([ src[payload_len_offset], src[payload_len_offset + 1], @@ -204,7 +283,7 @@ impl Decoder for TcpRequestCodec { return Ok(None); } - // We have a complete message, advance past length prefix + // We have a complete message, advance past endpoint path length prefix src.advance(2); // Read endpoint path @@ -216,6 +295,19 @@ impl Decoder for TcpRequestCodec { ) })?; + // Advance past headers length + src.advance(2); + + // Read and parse headers + let headers_bytes = src.split_to(headers_len); + let headers: std::collections::HashMap = + serde_json::from_slice(&headers_bytes).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid JSON in headers: {}", e), + ) + })?; + // Advance past payload length src.advance(4); @@ -224,6 +316,7 @@ impl Decoder for TcpRequestCodec { Ok(Some(TcpRequestMessage { endpoint_path, + headers, payload, })) } @@ -243,6 +336,22 @@ impl Encoder for TcpRequestCodec { )); } + // Encode headers as JSON + let headers_json = serde_json::to_vec(&item.headers).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Failed to encode headers: {}", e), + ) + })?; + let headers_len = headers_json.len(); + + if headers_len > u16::MAX as usize { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Headers too large: {} bytes", headers_len), + )); + } + if item.payload.len() > u32::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, @@ -250,7 +359,7 @@ impl Encoder for TcpRequestCodec { )); } - let total_len = 2 + endpoint_len + 4 + item.payload.len(); + let total_len = 2 + endpoint_len + 2 + headers_len + 4 + item.payload.len(); // Check max message size if let Some(max_size) = self.max_message_size @@ -274,6 +383,12 @@ impl Encoder for TcpRequestCodec { // Write endpoint path dst.put_slice(endpoint_bytes); + // Write headers length + dst.put_u16(headers_len as u16); + + // Write headers + dst.put_slice(&headers_json); + // Write payload length dst.put_u32(item.payload.len() as u32); diff --git a/lib/runtime/src/pipeline/network/egress/tcp_client.rs b/lib/runtime/src/pipeline/network/egress/tcp_client.rs index fdcfdced725..f784059ea8e 100644 --- a/lib/runtime/src/pipeline/network/egress/tcp_client.rs +++ b/lib/runtime/src/pipeline/network/egress/tcp_client.rs @@ -324,7 +324,8 @@ impl TcpConnection { // Encode request on caller's thread (hot path optimization) // This allows multiple concurrent callers to encode in parallel // rather than serializing through the writer task - let request_msg = TcpRequestMessage::new(endpoint_path, payload); + // Include all headers (especially trace headers) in the message + let request_msg = TcpRequestMessage::with_headers(endpoint_path, headers.clone(), payload); let encoded_data = request_msg.encode()?; // Create response channel @@ -657,7 +658,7 @@ mod tests { let (stream, _) = listener.accept().await.unwrap(); let (mut read_half, mut write_half) = tokio::io::split(stream); - // Read request + // Read path length and path let mut len_buf = [0u8; 2]; read_half.read_exact(&mut len_buf).await.unwrap(); let path_len = u16::from_be_bytes(len_buf) as usize; @@ -665,6 +666,15 @@ mod tests { let mut path_buf = vec![0u8; path_len]; read_half.read_exact(&mut path_buf).await.unwrap(); + // Read headers length and headers + let mut headers_len_buf = [0u8; 2]; + read_half.read_exact(&mut headers_len_buf).await.unwrap(); + let headers_len = u16::from_be_bytes(headers_len_buf) as usize; + + let mut headers_buf = vec![0u8; headers_len]; + read_half.read_exact(&mut headers_buf).await.unwrap(); + + // Read payload length and payload let mut len_buf = [0u8; 4]; read_half.read_exact(&mut len_buf).await.unwrap(); let payload_len = u32::from_be_bytes(len_buf) as usize; @@ -728,6 +738,17 @@ mod tests { break; } + let mut headers_len_buf = [0u8; 2]; + if read_half.read_exact(&mut headers_len_buf).await.is_err() { + break; + } + let headers_len = u16::from_be_bytes(headers_len_buf) as usize; + + let mut headers_buf = vec![0u8; headers_len]; + if read_half.read_exact(&mut headers_buf).await.is_err() { + break; + } + let mut len_buf = [0u8; 4]; if read_half.read_exact(&mut len_buf).await.is_err() { break; @@ -826,6 +847,17 @@ mod tests { break; } + let mut headers_len_buf = [0u8; 2]; + if read_half.read_exact(&mut headers_len_buf).await.is_err() { + break; + } + let headers_len = u16::from_be_bytes(headers_len_buf) as usize; + + let mut headers_buf = vec![0u8; headers_len]; + if read_half.read_exact(&mut headers_buf).await.is_err() { + break; + } + let mut len_buf = [0u8; 4]; if read_half.read_exact(&mut len_buf).await.is_err() { break; diff --git a/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs b/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs index f900155c1ac..0a4a7672c7b 100644 --- a/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs +++ b/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs @@ -266,6 +266,15 @@ impl SharedTcpServer { let mut path_buf = vec![0u8; path_len]; read_half.read_exact(&mut path_buf).await?; + // Read headers length (2 bytes) + let mut headers_len_buf = [0u8; 2]; + read_half.read_exact(&mut headers_len_buf).await?; + let headers_len = u16::from_be_bytes(headers_len_buf) as usize; + + // Read headers + let mut headers_buf = vec![0u8; headers_len]; + read_half.read_exact(&mut headers_buf).await?; + // Read payload length (4 bytes) let mut len_buf = [0u8; 4]; read_half.read_exact(&mut len_buf).await?; @@ -293,9 +302,12 @@ impl SharedTcpServer { read_half.read_exact(&mut payload_buf).await?; // Reconstruct the full message buffer for decoding using BytesMut - let mut full_msg = BytesMut::with_capacity(2 + path_len + 4 + payload_len); + let mut full_msg = + BytesMut::with_capacity(2 + path_len + 2 + headers_len + 4 + payload_len); full_msg.extend_from_slice(&path_len_buf); full_msg.extend_from_slice(&path_buf); + full_msg.extend_from_slice(&headers_len_buf); + full_msg.extend_from_slice(&headers_buf); full_msg.extend_from_slice(&len_buf); full_msg.extend_from_slice(&payload_buf); @@ -316,6 +328,7 @@ impl SharedTcpServer { }; let endpoint_path = request_msg.endpoint_path; + let headers = request_msg.headers; let payload = request_msg.payload; // Look up handler (lock-free read with DashMap) @@ -361,15 +374,18 @@ impl SharedTcpServer { tokio::spawn(async move { tracing::trace!(instance_id, "handling TCP request"); + // Create span with trace context from headers + let span = crate::logging::make_handle_payload_span_from_tcp_headers( + &headers, + &component_name, + &endpoint_name, + &namespace, + instance_id, + ); + let result = service_handler .handle_payload(payload) - .instrument(tracing::info_span!( - "handle_payload", - component = component_name.as_str(), - endpoint = endpoint_name.as_str(), - namespace = namespace.as_str(), - instance_id = instance_id, - )) + .instrument(span) .await; match result {