diff --git a/sgl-model-gateway/src/grpc_client/sglang_scheduler.rs b/sgl-model-gateway/src/grpc_client/sglang_scheduler.rs index 579f094eacf5..48db56d9c279 100644 --- a/sgl-model-gateway/src/grpc_client/sglang_scheduler.rs +++ b/sgl-model-gateway/src/grpc_client/sglang_scheduler.rs @@ -12,12 +12,15 @@ use std::{ use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; -use crate::protocols::{ - chat::ChatCompletionRequest, - common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}, - generate::GenerateRequest, - responses::ResponsesRequest, - sampling_params::SamplingParams as GenerateSamplingParams, +use crate::{ + observability::otel_trace::inject_trace_context_grpc, + protocols::{ + chat::ChatCompletionRequest, + common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}, + generate::GenerateRequest, + responses::ResponsesRequest, + sampling_params::SamplingParams as GenerateSamplingParams, + }, }; // Include the generated protobuf code @@ -163,7 +166,11 @@ impl SglangSchedulerClient { ) -> Result> { let request_id = req.request_id.clone(); let mut client = self.client.clone(); - let request = Request::new(req); + let mut request = Request::new(req); + + // Inject W3C trace context into gRPC metadata for distributed tracing + inject_trace_context_grpc(request.metadata_mut()); + let response = client.generate(request).await?; Ok(AbortOnDropStream::new( diff --git a/sgl-model-gateway/src/grpc_client/vllm_engine.rs b/sgl-model-gateway/src/grpc_client/vllm_engine.rs index 5a95bc1111ed..986682f8bed2 100644 --- a/sgl-model-gateway/src/grpc_client/vllm_engine.rs +++ b/sgl-model-gateway/src/grpc_client/vllm_engine.rs @@ -12,12 +12,15 @@ use std::{ use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; -use crate::protocols::{ - chat::ChatCompletionRequest, - common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}, - generate::GenerateRequest, - responses::ResponsesRequest, - sampling_params::SamplingParams as GenerateSamplingParams, +use crate::{ + observability::otel_trace::inject_trace_context_grpc, + protocols::{ + chat::ChatCompletionRequest, + common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}, + generate::GenerateRequest, + responses::ResponsesRequest, + sampling_params::SamplingParams as GenerateSamplingParams, + }, }; // Include the generated protobuf code @@ -163,7 +166,11 @@ impl VllmEngineClient { ) -> Result> { let request_id = req.request_id.clone(); let mut client = self.client.clone(); - let request = Request::new(req); + let mut request = Request::new(req); + + // Inject W3C trace context into gRPC metadata for distributed tracing + inject_trace_context_grpc(request.metadata_mut()); + let response = client.generate(request).await?; Ok(AbortOnDropStream::new( diff --git a/sgl-model-gateway/src/observability/otel_trace.rs b/sgl-model-gateway/src/observability/otel_trace.rs index 1782decdfec0..7e0e2c9c288f 100644 --- a/sgl-model-gateway/src/observability/otel_trace.rs +++ b/sgl-model-gateway/src/observability/otel_trace.rs @@ -17,6 +17,7 @@ use opentelemetry_sdk::{ Resource, }; use tokio::task::spawn_blocking; +use tonic::metadata::{MetadataKey, MetadataMap, MetadataValue}; use tracing::{Metadata, Subscriber}; use tracing_opentelemetry::{self, OpenTelemetrySpanExt}; use tracing_subscriber::{ @@ -240,3 +241,32 @@ pub fn inject_trace_context_http(headers: &mut HeaderMap) { propagator.inject_context(&context, &mut HeaderInjector(headers)); }); } + +/// Inject W3C trace context into gRPC metadata. +/// +/// This propagates the current span context to downstream gRPC services. +/// Does nothing if OTEL is not enabled. +pub fn inject_trace_context_grpc(metadata: &mut MetadataMap) { + if !is_otel_enabled() { + return; + } + + let context = tracing::Span::current().context(); + + struct MetadataInjector<'a>(&'a mut MetadataMap); + + impl opentelemetry::propagation::Injector for MetadataInjector<'_> { + fn set(&mut self, key: &str, value: String) { + // gRPC metadata keys must be lowercase ASCII + if let Ok(metadata_key) = MetadataKey::from_bytes(key.to_lowercase().as_bytes()) { + if let Ok(metadata_value) = MetadataValue::try_from(&value) { + self.0.insert(metadata_key, metadata_value); + } + } + } + } + + global::get_text_map_propagator(|propagator| { + propagator.inject_context(&context, &mut MetadataInjector(metadata)); + }); +} diff --git a/sgl-model-gateway/src/policies/tree.rs b/sgl-model-gateway/src/policies/tree.rs index 6857d75b16ec..f8c67e1869df 100644 --- a/sgl-model-gateway/src/policies/tree.rs +++ b/sgl-model-gateway/src/policies/tree.rs @@ -590,30 +590,6 @@ impl Tree { .collect() } - pub fn get_smallest_tenant(&self) -> String { - // Return a placeholder if there are no tenants - if self.tenant_char_count.is_empty() { - return "empty".to_string(); - } - - // Find the tenant with minimum char count - let mut min_tenant = None; - let mut min_count = usize::MAX; - - for entry in self.tenant_char_count.iter() { - let tenant = entry.key(); - let count = *entry.value(); - - if count < min_count { - min_count = count; - min_tenant = Some(tenant.clone()); - } - } - - // Return the found tenant or "empty" if somehow none was found - min_tenant.unwrap_or_else(|| "empty".to_string()) - } - #[allow(dead_code)] pub fn get_used_size_per_tenant(&self) -> HashMap { // perform a DFS to traverse all nodes and calculate the total size used by each tenant @@ -728,54 +704,6 @@ mod tests { use super::*; - #[test] - fn test_get_smallest_tenant() { - let tree = Tree::new(); - - assert_eq!(tree.get_smallest_tenant(), "empty"); - - // Insert data for tenant1 - "ap" + "icot" = 6 chars - tree.insert("ap", "tenant1"); - tree.insert("icot", "tenant1"); - - // Insert data for tenant2 - "cat" = 3 chars - tree.insert("cat", "tenant2"); - - assert_eq!( - tree.get_smallest_tenant(), - "tenant2", - "Expected tenant2 to be smallest with 3 characters." - ); - - // Insert overlapping data for tenant3 and tenant4 to test equal counts - // tenant3: "do" = 2 chars - // tenant4: "hi" = 2 chars - tree.insert("do", "tenant3"); - tree.insert("hi", "tenant4"); - - let smallest = tree.get_smallest_tenant(); - assert!( - smallest == "tenant3" || smallest == "tenant4", - "Expected either tenant3 or tenant4 (both have 2 characters), got {}", - smallest - ); - - // Add more text to tenant4 to make it larger - tree.insert("hello", "tenant4"); // Now tenant4 has "hi" + "hello" = 6 chars - - // Now tenant3 should be smallest (2 chars vs 6 chars for tenant4) - assert_eq!( - tree.get_smallest_tenant(), - "tenant3", - "Expected tenant3 to be smallest with 2 characters" - ); - - tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars - - let post_eviction_smallest = tree.get_smallest_tenant(); - println!("Smallest tenant after eviction: {}", post_eviction_smallest); - } - #[test] fn test_tenant_char_count() { let tree = Tree::new(); diff --git a/sgl-model-gateway/src/routers/grpc/common/stages/request_execution.rs b/sgl-model-gateway/src/routers/grpc/common/stages/request_execution.rs index ed2b44a652d7..89898451ddd4 100644 --- a/sgl-model-gateway/src/routers/grpc/common/stages/request_execution.rs +++ b/sgl-model-gateway/src/routers/grpc/common/stages/request_execution.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info_span, Instrument}; use super::PipelineStage; use crate::routers::grpc::{ @@ -18,6 +18,7 @@ pub struct RequestExecutionStage { mode: ExecutionMode, } +#[derive(Debug, Clone, Copy)] pub enum ExecutionMode { /// Regular mode: single worker execution Single, @@ -50,12 +51,39 @@ impl PipelineStage for RequestExecutionStage { error::internal_error("Client acquisition not completed") })?; - let result = match self.mode { - ExecutionMode::Single => self.execute_single(proto_request, clients).await?, - ExecutionMode::DualDispatch => { - self.execute_dual_dispatch(proto_request, clients).await? + // Extract dispatch metadata for tracing span + let request_id = ctx + .state + .dispatch + .as_ref() + .map(|d| d.request_id.as_str()) + .unwrap_or("unknown"); + let model = ctx + .state + .dispatch + .as_ref() + .map(|d| d.model.as_str()) + .unwrap_or("unknown"); + + // Create OTEL span for gRPC request execution + let span = info_span!( + target: "sgl_model_gateway::otel-trace", + "grpc_generate", + request_id = %request_id, + model = %model, + mode = ?self.mode, + ); + + let result = async { + match self.mode { + ExecutionMode::Single => self.execute_single(proto_request, clients).await, + ExecutionMode::DualDispatch => { + self.execute_dual_dispatch(proto_request, clients).await + } } - }; + } + .instrument(span) + .await?; // Store result in context for ResponseProcessingStage ctx.state.response.execution_result = Some(result); diff --git a/sgl-model-gateway/tests/otel_tracing_test.rs b/sgl-model-gateway/tests/otel_tracing_test.rs index fc2c165906e1..6bb616c9f3ab 100644 --- a/sgl-model-gateway/tests/otel_tracing_test.rs +++ b/sgl-model-gateway/tests/otel_tracing_test.rs @@ -24,9 +24,11 @@ use sgl_model_gateway::{ routers::RouterFactory, }; use tokio::sync::oneshot; +use tonic::metadata::MetadataMap; use tonic_v12::{transport::Server, Request as TonicRequest, Response, Status}; use tower::ServiceExt; use tracing::info_span; +use tracing_subscriber::prelude::*; #[derive(Clone)] struct TestOtelCollector { @@ -131,14 +133,22 @@ async fn test_router_with_tracing() { .enable_trace(&collector_endpoint) .build_unchecked(); - // 4. Initialize the OTLP client - let init_result = otel_trace::otel_tracing_init(true, Some(&collector_endpoint)); - assert!( - init_result.is_ok(), - "Failed to initialize OTEL: {:?}", - init_result.err() - ); - println!("OpenTelemetry initialized successfully"); + // 4. Initialize the OTLP client (check if already initialized by another test) + let otel_initialized_by_this_test = if !otel_trace::is_otel_enabled() { + let init_result = otel_trace::otel_tracing_init(true, Some(&collector_endpoint)); + assert!( + init_result.is_ok(), + "Failed to initialize OTEL: {:?}", + init_result.err() + ); + println!("OpenTelemetry initialized successfully"); + true + } else { + println!( + "OpenTelemetry already initialized by previous test (spans will go to that collector)" + ); + false + }; let trace_config = TraceConfig { enable_trace: true, @@ -232,14 +242,23 @@ async fn test_router_with_tracing() { let span_count = collector.get_span_count(); println!("Total spans received by collector: {}", span_count); - assert!( - span_count == 2, - "Expected to receive at least 2 span, but got {}. \ - This indicates that tracing data is not being exported to the OTLP collector.", - span_count - ); - - println!("Test passed! Collector received {} spans", span_count); + // Only assert span count if we initialized OTEL with our own collector + // When OTEL was pre-initialized by another test, spans go to that collector instead + if otel_initialized_by_this_test { + assert!( + span_count == 2, + "Expected to receive at least 2 span, but got {}. \ + This indicates that tracing data is not being exported to the OTLP collector.", + span_count + ); + println!("Test passed! Collector received {} spans", span_count); + } else { + println!( + "Skipping span count assertion - OTEL was pre-initialized by another test. \ + Spans went to that collector. Received {} spans on this test's collector.", + span_count + ); + } // 13. cleanup let _ = shutdown_tx.send(()); @@ -247,3 +266,115 @@ async fn test_router_with_tracing() { println!("Cleanup completed"); } + +// ============================================================================ +// gRPC Trace Context Injection Tests +// ============================================================================ + +/// Comprehensive test for gRPC trace context injection. +/// +/// This test validates: +/// 1. W3C trace context headers are properly injected into gRPC metadata +/// 2. traceparent format is correct (version-traceid-spanid-flags) +/// 3. All metadata keys are lowercase (gRPC requirement) +/// +/// Note: This test handles the case where OTEL may already be initialized +/// by a previous test (since tests run sequentially with #[serial]). +#[tokio::test] +#[serial] +async fn test_grpc_trace_context_injection() { + // 1. Start the OTLP collector (needed even if OTEL is already initialized, + // as a target for any spans that might be exported) + let port = pick_unused_port().expect("Failed to pick unused port"); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let _collector = start_collector(port, shutdown_rx) + .await + .expect("Failed to start collector"); + let collector_endpoint = format!("0.0.0.0:{}", port); + + // 2. Initialize OTEL if not already enabled + // Note: otel_tracing_init will fail if already initialized (OnceLock), + // but that's fine - we just need OTEL to be enabled + let already_enabled = otel_trace::is_otel_enabled(); + if !already_enabled { + let init_result = otel_trace::otel_tracing_init(true, Some(&collector_endpoint)); + assert!( + init_result.is_ok(), + "Failed to initialize OTEL: {:?}", + init_result.err() + ); + } + + // Verify OTEL is enabled (either from this test or a previous one) + assert!(otel_trace::is_otel_enabled(), "OTEL should be enabled"); + + // 3. Set up tracing subscriber with OTEL layer + let otel_layer = otel_trace::get_otel_layer().expect("Failed to get OTEL layer"); + let subscriber = tracing_subscriber::registry().with(otel_layer); + + // 4. Test within a span context + tracing::subscriber::with_default(subscriber, || { + // Create a span that will be exported to OTEL + let span = info_span!(target: "sgl_model_gateway::otel-trace", "test_grpc_span"); + let _guard = span.enter(); + + // Create empty gRPC metadata + let mut metadata = MetadataMap::new(); + + // Inject trace context + otel_trace::inject_trace_context_grpc(&mut metadata); + + // === Test 1: Verify traceparent header was injected === + let traceparent = metadata.get("traceparent"); + assert!( + traceparent.is_some(), + "traceparent header should be present in gRPC metadata" + ); + + // === Test 2: Verify traceparent format (version-traceid-spanid-flags) === + let traceparent_value = traceparent.unwrap().to_str().unwrap(); + let parts: Vec<&str> = traceparent_value.split('-').collect(); + assert_eq!( + parts.len(), + 4, + "traceparent should have 4 parts: version-traceid-spanid-flags" + ); + assert_eq!(parts[0], "00", "traceparent version should be 00"); + assert_eq!(parts[1].len(), 32, "trace ID should be 32 hex characters"); + assert_eq!(parts[2].len(), 16, "span ID should be 16 hex characters"); + + println!("Successfully injected traceparent: {}", traceparent_value); + + // === Test 3: Verify all keys are lowercase (gRPC metadata requirement) === + for key_and_value in metadata.iter() { + match key_and_value { + tonic::metadata::KeyAndValueRef::Ascii(key, _) => { + let key_str = key.as_str(); + assert_eq!( + key_str, + key_str.to_lowercase(), + "gRPC metadata key '{}' should be lowercase", + key_str + ); + } + tonic::metadata::KeyAndValueRef::Binary(key, _) => { + let key_str = key.as_str(); + assert_eq!( + key_str, + key_str.to_lowercase(), + "gRPC metadata key '{}' should be lowercase", + key_str + ); + } + } + } + + println!("All gRPC metadata keys are lowercase as required"); + }); + + // Cleanup - don't shutdown OTEL since tests share global state (OnceLock) + // and other tests may need to use the already-initialized OTEL + let _ = shutdown_tx.send(()); + + println!("test_grpc_trace_context_injection: All assertions passed!"); +}