1 change: 1 addition & 0 deletions sgl-router/benches/request_processing.rs
@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
logit_bias: None,
user: None,
seed: None,
other: serde_json::Map::new(),
}
}

4 changes: 4 additions & 0 deletions sgl-router/src/openai_api_types.rs
@@ -91,6 +91,10 @@ pub struct CompletionRequest {
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,

/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub other: serde_json::Map<String, serde_json::Value>,
}

impl GenerationRequest for CompletionRequest {
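Note on the `#[serde(flatten)]` field added above: flattening into a `serde_json::Map` captures any JSON keys the struct does not name and re-emits them at the top level on serialization, which is what lets bootstrap info ride along in an otherwise unchanged OpenAI request. A minimal sketch of the pattern, using a hypothetical `Request` struct rather than the real `CompletionRequest`:

```rust
use serde::{Deserialize, Serialize};

// Hypothetical stand-in for CompletionRequest: unknown keys land in `other`
// on deserialization and are written back at the top level on serialization.
#[derive(Serialize, Deserialize)]
struct Request {
    model: String,
    #[serde(flatten)]
    other: serde_json::Map<String, serde_json::Value>,
}

fn main() {
    let body = r#"{"model":"m","bootstrap_room":42}"#;
    let req: Request = serde_json::from_str(body).unwrap();
    assert_eq!(req.other.get("bootstrap_room").unwrap().as_u64().unwrap(), 42);
    // The extra key survives a round trip at the top level.
    assert!(serde_json::to_string(&req).unwrap().contains("bootstrap_room"));
}
```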
90 changes: 75 additions & 15 deletions sgl-router/src/routers/pd_router.rs
@@ -420,6 +420,77 @@ impl PDRouter {
.await
}

// Route a completion request while preserving OpenAI format
pub async fn route_completion(
&self,
client: &reqwest::Client,
req: &HttpRequest,
mut typed_req: CompletionRequest,
route: &str,
) -> HttpResponse {
let start = Instant::now();

// Get stream flag and return_logprob flag before moving the request
let is_stream = typed_req.stream;
let return_logprob = typed_req.logprobs.is_some();

// Extract text for cache-aware routing from the typed request
let request_text = match &typed_req.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
};

// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair: {}", e);
RouterMetrics::record_pd_error("server_selection");
return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e));
}
};

// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route,
prefill.url(),
decode.url()
);

// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e));
}

// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request: {}", e);
return HttpResponse::InternalServerError().body("Failed to serialize request");
}
};

// Execute dual dispatch
self.execute_dual_dispatch(
client,
req,
json_with_bootstrap,
route,
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
)
.await
}

// Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch(
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse {
-match serde_json::from_value::<CompletionRequest>(body.clone()) {
+match serde_json::from_value::<CompletionRequest>(body) {
Ok(openai_req) => {
-// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
-let pd_req = openai_req.to_pd_request();
-PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
-}
-Err(_) => {
-// If that fails, try to deserialize directly as PD format (for backwards compatibility)
-match serde_json::from_value::<GenerateReqInput>(body) {
-Ok(pd_req) => {
-PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
-}
-Err(e) => {
-HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
-}
-}
+// Use the new method that preserves OpenAI format
+PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
}
+Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
}
}

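For reviewers tracing the new flow: `route_completion` keeps the request typed as `CompletionRequest`, injects bootstrap info through the `Bootstrap` trait, and only then serializes for the dual dispatch. Roughly — and this is an illustration with made-up values, not captured traffic — the body forwarded to the prefill and decode workers is the original OpenAI payload plus the flattened bootstrap keys:

```rust
use serde_json::json;

fn main() {
    // Illustration only (values are made up): the forwarded body keeps the
    // OpenAI completion fields and gains the flattened bootstrap keys,
    // instead of being rewritten into the internal GenerateReqInput shape.
    let forwarded = json!({
        "model": "my-model",
        "prompt": "Hello",
        "stream": false,
        "bootstrap_host": "prefill-0", // injected by add_bootstrap_info
        "bootstrap_port": 8998,
        "bootstrap_room": 123456789u64
    });
    println!("{}", serde_json::to_string_pretty(&forwarded).unwrap());
}
```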
233 changes: 233 additions & 0 deletions sgl-router/src/routers/pd_types.rs
@@ -1,6 +1,7 @@
// Essential PDLB types extracted for PD routing

use crate::core::{Worker, WorkerType};
use crate::openai_api_types::{CompletionRequest, StringOrArray};
use serde::{Deserialize, Serialize};
use serde_json::Value;

@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput {
self.bootstrap_room = Some(bootstrap_room);
}
}

// Bootstrap implementation for CompletionRequest to preserve OpenAI format
impl Bootstrap for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}

fn get_batch_size(&self) -> Result<Option<usize>, String> {
if let StringOrArray::Array(prompts) = &self.prompt {
if prompts.is_empty() {
return Err("Batch prompt array is empty".to_string());
}
return Ok(Some(prompts.len()));
}

// Single string prompt
Ok(None)
}

fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
// Insert bootstrap_host - it serializes correctly whether Single or Batch
if let Ok(host_value) = serde_json::to_value(&bootstrap_host) {
self.other.insert("bootstrap_host".to_string(), host_value);
}

// Insert bootstrap_port - it serializes correctly whether Single or Batch
if let Ok(port_value) = serde_json::to_value(&bootstrap_port) {
self.other.insert("bootstrap_port".to_string(), port_value);
}

// Insert bootstrap_room - it serializes correctly whether Single or Batch
if let Ok(room_value) = serde_json::to_value(&bootstrap_room) {
self.other.insert("bootstrap_room".to_string(), room_value);
}
}
}

#[cfg(test)]
mod bootstrap_tests {
use super::*;
use crate::openai_api_types::StringOrArray;

#[test]
fn test_completion_batch_size_with_array_prompt() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
n: None,
other: serde_json::Map::new(),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
};

// Should return batch size for array prompt
assert_eq!(req.get_batch_size().unwrap(), Some(2));
}

#[test]
fn test_completion_batch_size_with_single_prompt() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("single prompt".to_string()),
n: None,
other: serde_json::Map::new(),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
};

// Should return None for single prompt
assert_eq!(req.get_batch_size().unwrap(), None);
}

#[test]
fn test_completion_batch_size_with_n_parameter() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("single prompt".to_string()),
n: Some(3),
other: serde_json::Map::new(),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
};

// Should return None for single string prompt, even with n > 1
// SGLang handles n parameter differently than batch requests
assert_eq!(req.get_batch_size().unwrap(), None);
}

#[test]
fn test_completion_bootstrap_single_values() {
let mut req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
n: None,
other: serde_json::Map::new(),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
};

// Set bootstrap info - should always use single values
req.set_bootstrap_info(
BootstrapHost::Single("test-server".to_string()),
BootstrapPort::Single(Some(5678)),
BootstrapRoom::Single(12345),
);

// Verify single values were created
assert!(req.other.get("bootstrap_host").unwrap().is_string());
assert!(req.other.get("bootstrap_port").unwrap().is_number());
assert!(req.other.get("bootstrap_room").unwrap().is_number());

assert_eq!(
req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
"test-server"
);
assert_eq!(
req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
5678
);
assert_eq!(
req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
12345
);
}

#[test]
fn test_completion_bootstrap_array_values() {
let mut req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
n: None,
other: serde_json::Map::new(),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
};

// Set bootstrap info with arrays
req.set_bootstrap_info(
BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
BootstrapPort::Batch(vec![Some(5678); 2]),
BootstrapRoom::Batch(vec![12345, 67890]),
);

// Verify arrays were created correctly
assert!(req.other.get("bootstrap_host").unwrap().is_array());
assert!(req.other.get("bootstrap_port").unwrap().is_array());
assert!(req.other.get("bootstrap_room").unwrap().is_array());

let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
assert_eq!(hosts.len(), 2);
assert_eq!(hosts[0].as_str().unwrap(), "test-server");

let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
assert_eq!(ports.len(), 2);
assert_eq!(ports[0].as_u64().unwrap(), 5678);

let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
assert_eq!(rooms.len(), 2);
assert_eq!(rooms[0].as_u64().unwrap(), 12345);
assert_eq!(rooms[1].as_u64().unwrap(), 67890);
}
}
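The `set_bootstrap_info` comments above rely on `BootstrapHost`/`BootstrapPort`/`BootstrapRoom` serializing to a bare value for `Single` and a JSON array for `Batch`; those definitions are not part of this diff, so here is a minimal sketch of the assumed untagged-enum behaviour using a hypothetical `RoomLike` type (consistent with what the tests assert, not the actual definition):

```rust
use serde::Serialize;

// Hypothetical stand-in for BootstrapRoom: with #[serde(untagged)], the
// variant name is dropped, so Single becomes a bare number and Batch an array.
#[derive(Serialize)]
#[serde(untagged)]
enum RoomLike {
    Single(u64),
    Batch(Vec<u64>),
}

fn main() {
    assert_eq!(serde_json::to_string(&RoomLike::Single(12345)).unwrap(), "12345");
    assert_eq!(
        serde_json::to_string(&RoomLike::Batch(vec![12345, 67890])).unwrap(),
        "[12345,67890]"
    );
}
```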