Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions sgl-model-gateway/src/grpc_client/sglang_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ use std::{
use tonic::{transport::Channel, Request, Streaming};
use tracing::{debug, warn};

use crate::protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
use crate::{
observability::otel_trace::inject_trace_context_grpc,
protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
},
};

// Include the generated protobuf code
Expand Down Expand Up @@ -163,7 +166,11 @@ impl SglangSchedulerClient {
) -> Result<AbortOnDropStream, Box<dyn std::error::Error + Send + Sync>> {
let request_id = req.request_id.clone();
let mut client = self.client.clone();
let request = Request::new(req);
let mut request = Request::new(req);

// Inject W3C trace context into gRPC metadata for distributed tracing
inject_trace_context_grpc(request.metadata_mut());

let response = client.generate(request).await?;

Ok(AbortOnDropStream::new(
Expand Down
21 changes: 14 additions & 7 deletions sgl-model-gateway/src/grpc_client/vllm_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ use std::{
use tonic::{transport::Channel, Request, Streaming};
use tracing::{debug, warn};

use crate::protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
use crate::{
observability::otel_trace::inject_trace_context_grpc,
protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
},
};

// Include the generated protobuf code
Expand Down Expand Up @@ -163,7 +166,11 @@ impl VllmEngineClient {
) -> Result<AbortOnDropStream, Box<dyn std::error::Error + Send + Sync>> {
let request_id = req.request_id.clone();
let mut client = self.client.clone();
let request = Request::new(req);
let mut request = Request::new(req);

// Inject W3C trace context into gRPC metadata for distributed tracing
inject_trace_context_grpc(request.metadata_mut());

let response = client.generate(request).await?;

Ok(AbortOnDropStream::new(
Expand Down
30 changes: 30 additions & 0 deletions sgl-model-gateway/src/observability/otel_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use opentelemetry_sdk::{
Resource,
};
use tokio::task::spawn_blocking;
use tonic::metadata::{MetadataKey, MetadataMap, MetadataValue};
use tracing::{Metadata, Subscriber};
use tracing_opentelemetry::{self, OpenTelemetrySpanExt};
use tracing_subscriber::{
Expand Down Expand Up @@ -240,3 +241,32 @@ pub fn inject_trace_context_http(headers: &mut HeaderMap) {
propagator.inject_context(&context, &mut HeaderInjector(headers));
});
}

/// Inject W3C trace context into gRPC metadata.
///
/// Propagates the current tracing span's context to downstream gRPC
/// services so distributed traces stay connected across hops.
/// Does nothing if OTEL is not enabled.
pub fn inject_trace_context_grpc(metadata: &mut MetadataMap) {
    if !is_otel_enabled() {
        return;
    }

    // Adapter that lets the OTEL text-map propagator write key/value
    // pairs into tonic's `MetadataMap`.
    struct MetadataInjector<'a>(&'a mut MetadataMap);

    impl opentelemetry::propagation::Injector for MetadataInjector<'_> {
        fn set(&mut self, key: &str, value: String) {
            // gRPC metadata keys must be lowercase ASCII; entries that
            // cannot be represented are silently skipped rather than
            // failing the request.
            let lowered = key.to_lowercase();
            if let (Ok(meta_key), Ok(meta_value)) = (
                MetadataKey::from_bytes(lowered.as_bytes()),
                MetadataValue::try_from(&value),
            ) {
                self.0.insert(meta_key, meta_value);
            }
        }
    }

    let context = tracing::Span::current().context();

    global::get_text_map_propagator(|propagator| {
        propagator.inject_context(&context, &mut MetadataInjector(metadata));
    });
}
72 changes: 0 additions & 72 deletions sgl-model-gateway/src/policies/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,30 +590,6 @@ impl Tree {
.collect()
}

/// Return the tenant with the smallest tracked character count.
///
/// Falls back to the placeholder string "empty" when no tenants are
/// registered. When several tenants tie for the minimum, the first one
/// encountered during iteration wins (iteration order of the underlying
/// map is unspecified).
pub fn get_smallest_tenant(&self) -> String {
    // Fold over all (tenant, count) entries, keeping the first entry
    // whose count is strictly smaller than the best seen so far.
    // An empty map folds straight through to `None`.
    self.tenant_char_count
        .iter()
        .fold((usize::MAX, None), |best, entry| {
            let count = *entry.value();
            if count < best.0 {
                (count, Some(entry.key().clone()))
            } else {
                best
            }
        })
        .1
        .unwrap_or_else(|| "empty".to_string())
}

#[allow(dead_code)]
pub fn get_used_size_per_tenant(&self) -> HashMap<String, usize> {
// perform a DFS to traverse all nodes and calculate the total size used by each tenant
Expand Down Expand Up @@ -728,54 +704,6 @@ mod tests {

use super::*;

#[test]
fn test_get_smallest_tenant() {
    let tree = Tree::new();

    // With no tenants registered, the placeholder "empty" is returned.
    assert_eq!(tree.get_smallest_tenant(), "empty");

    // Insert data for tenant1 - "ap" + "icot" = 6 chars
    tree.insert("ap", "tenant1");
    tree.insert("icot", "tenant1");

    // Insert data for tenant2 - "cat" = 3 chars
    tree.insert("cat", "tenant2");

    assert_eq!(
        tree.get_smallest_tenant(),
        "tenant2",
        "Expected tenant2 to be smallest with 3 characters."
    );

    // Insert overlapping data for tenant3 and tenant4 to test equal counts.
    // Ties are resolved by map iteration order, which is unspecified, so
    // either tenant is acceptable below.
    // tenant3: "do" = 2 chars
    // tenant4: "hi" = 2 chars
    tree.insert("do", "tenant3");
    tree.insert("hi", "tenant4");

    let smallest = tree.get_smallest_tenant();
    assert!(
        smallest == "tenant3" || smallest == "tenant4",
        "Expected either tenant3 or tenant4 (both have 2 characters), got {}",
        smallest
    );

    // Add more text to tenant4 to make it larger
    tree.insert("hello", "tenant4"); // Now tenant4 has "hi" + "hello" = 6 chars

    // Now tenant3 should be smallest (2 chars vs 6 chars for tenant4)
    assert_eq!(
        tree.get_smallest_tenant(),
        "tenant3",
        "Expected tenant3 to be smallest with 2 characters"
    );

    // NOTE(review): assumes evict_tenant_by_size(3) evicts tenants holding
    // more than 3 chars — confirm against evict_tenant_by_size's contract.
    // No assertion is made on the post-eviction result; only that the call
    // and a subsequent lookup do not panic.
    tree.evict_tenant_by_size(3); // This should evict tenants with more than 3 chars

    let post_eviction_smallest = tree.get_smallest_tenant();
    println!("Smallest tenant after eviction: {}", post_eviction_smallest);
}

#[test]
fn test_tenant_char_count() {
let tree = Tree::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use async_trait::async_trait;
use axum::response::Response;
use tracing::error;
use tracing::{error, info_span, Instrument};

use super::PipelineStage;
use crate::routers::grpc::{
Expand All @@ -18,6 +18,7 @@ pub struct RequestExecutionStage {
mode: ExecutionMode,
}

#[derive(Debug, Clone, Copy)]
pub enum ExecutionMode {
/// Regular mode: single worker execution
Single,
Expand Down Expand Up @@ -50,12 +51,39 @@ impl PipelineStage for RequestExecutionStage {
error::internal_error("Client acquisition not completed")
})?;

let result = match self.mode {
ExecutionMode::Single => self.execute_single(proto_request, clients).await?,
ExecutionMode::DualDispatch => {
self.execute_dual_dispatch(proto_request, clients).await?
// Extract dispatch metadata for tracing span
let request_id = ctx
.state
.dispatch
.as_ref()
.map(|d| d.request_id.as_str())
.unwrap_or("unknown");
let model = ctx
.state
.dispatch
.as_ref()
.map(|d| d.model.as_str())
.unwrap_or("unknown");

// Create OTEL span for gRPC request execution
let span = info_span!(
target: "sgl_model_gateway::otel-trace",
"grpc_generate",
request_id = %request_id,
model = %model,
mode = ?self.mode,
);

let result = async {
match self.mode {
ExecutionMode::Single => self.execute_single(proto_request, clients).await,
ExecutionMode::DualDispatch => {
self.execute_dual_dispatch(proto_request, clients).await
}
}
};
}
.instrument(span)
.await?;

// Store result in context for ResponseProcessingStage
ctx.state.response.execution_result = Some(result);
Expand Down
Loading
Loading