diff --git a/sgl-model-gateway/src/policies/tree.rs b/sgl-model-gateway/src/policies/tree.rs index 4ad10fff1f27..6857d75b16ec 100644 --- a/sgl-model-gateway/src/policies/tree.rs +++ b/sgl-model-gateway/src/policies/tree.rs @@ -10,6 +10,36 @@ use tracing::info; type NodeRef = Arc; +/// Pre-indexed text for efficient character access. +/// Collects the chars of a UTF-8 string into a Vec once so that lookup by char index is O(1). +struct CharIndexedText { + chars: Vec, +} + +impl CharIndexedText { + #[inline] + fn new(text: &str) -> Self { + Self { + chars: text.chars().collect(), + } + } + + #[inline] + fn len(&self) -> usize { + self.chars.len() + } + + #[inline] + fn get(&self, idx: usize) -> Option { + self.chars.get(idx).copied() + } + + #[inline] + fn slice_to_string(&self, start: usize, end: usize) -> String { + self.chars[start..end].iter().collect() + } +} + #[derive(Debug)] struct Node { children: DashMap, @@ -57,13 +87,14 @@ impl PartialEq for EvictionEntry { // Note that in rust, `.len()` or slice is operated on the "byte" level. It causes issues for UTF-8 characters because one character might use multiple bytes. 
// https://en.wikipedia.org/wiki/UTF-8 -fn shared_prefix_count(a: &str, b: &str) -> usize { +/// Efficient shared prefix count using pre-indexed chars for O(1) access +#[inline] +fn shared_prefix_count_indexed(a: &CharIndexedText, a_start: usize, b: &str) -> usize { let mut i = 0; - let mut a_iter = a.chars(); let mut b_iter = b.chars(); - loop { - match (a_iter.next(), b_iter.next()) { + while a_start + i < a.len() { + match (a.get(a_start + i), b_iter.next()) { (Some(a_char), Some(b_char)) if a_char == b_char => { i += 1; } @@ -107,6 +138,9 @@ impl Tree { pub fn insert(&self, text: &str, tenant: &str) { // Insert text into tree with given tenant + // Pre-index text once for O(1) character access (avoids O(n²) chars().nth() calls) + let indexed_text = CharIndexedText::new(text); + let text_count = indexed_text.len(); let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; @@ -125,10 +159,9 @@ impl Tree { let mut prev = Arc::clone(&self.root); - let text_count = text.chars().count(); - while curr_idx < text_count { - let first_char = text.chars().nth(curr_idx).unwrap(); + // O(1) character access instead of O(n) chars().nth() + let first_char = indexed_text.get(curr_idx).unwrap(); curr = prev; @@ -146,8 +179,9 @@ impl Tree { [curr] => [new node] */ - let curr_text = slice_by_chars(text, curr_idx, text_count); - let curr_text_count = curr_text.chars().count(); + // Use indexed slice for efficient string extraction + let curr_text = indexed_text.slice_to_string(curr_idx, text_count); + let curr_text_count = text_count - curr_idx; let new_node = Arc::new(Node { children: DashMap::new(), text: RwLock::new(curr_text), @@ -174,11 +208,12 @@ impl Tree { // matched let matched_node = entry.get().clone(); - let matched_node_text = matched_node.text.read().unwrap().to_owned(); + let matched_node_text = matched_node.text.read().unwrap(); let matched_node_text_count = matched_node_text.chars().count(); - let curr_text = slice_by_chars(text, curr_idx, text_count); - let 
shared_count = shared_prefix_count(&matched_node_text, &curr_text); + // Use indexed comparison to avoid creating intermediate string + let shared_count = + shared_prefix_count_indexed(&indexed_text, curr_idx, &matched_node_text); if shared_count < matched_node_text_count { /* @@ -194,7 +229,10 @@ impl Tree { shared_count, matched_node_text_count, ); - let matched_text_count = matched_text.chars().count(); + let matched_text_count = shared_count; + + // Drop read lock before creating new node + drop(matched_node_text); let new_node = Arc::new(Node { text: RwLock::new(matched_text), @@ -203,7 +241,7 @@ impl Tree { tenant_last_access_time: matched_node.tenant_last_access_time.clone(), }); - let first_new_char = contracted_text.chars().nth(0).unwrap(); + let first_new_char = contracted_text.chars().next().unwrap(); new_node .children .insert(first_new_char, Arc::clone(&matched_node)); @@ -232,6 +270,9 @@ impl Tree { curr_idx += shared_count; } else { // move to next node + // Drop read lock before continuing + drop(matched_node_text); + prev = Arc::clone(&matched_node); // Atomically attach tenant to existing node and increment count once @@ -256,22 +297,27 @@ impl Tree { #[allow(unused_assignments)] pub fn prefix_match(&self, text: &str) -> (String, String) { + // Pre-index text once for O(1) character access + let indexed_text = CharIndexedText::new(text); + let text_count = indexed_text.len(); + let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; let mut prev = Arc::clone(&self.root); - let text_count = text.chars().count(); while curr_idx < text_count { - let first_char = text.chars().nth(curr_idx).unwrap(); - let curr_text = slice_by_chars(text, curr_idx, text_count); + // O(1) character access instead of O(n) chars().nth() + let first_char = indexed_text.get(curr_idx).unwrap(); curr = prev.clone(); if let Some(entry) = curr.children.get(&first_char) { let matched_node = entry.value().clone(); let matched_text_guard = matched_node.text.read().unwrap(); - 
let shared_count = shared_prefix_count(&matched_text_guard, &curr_text); + // Use indexed comparison to avoid creating intermediate string + let shared_count = + shared_prefix_count_indexed(&indexed_text, curr_idx, &matched_text_guard); let matched_node_text_count = matched_text_guard.chars().count(); drop(matched_text_guard); @@ -299,7 +345,7 @@ impl Tree { .iter() .next() .map(|kv| kv.key().to_owned()) - .unwrap_or("empty".to_string()); + .unwrap_or_else(|| "empty".to_string()); // Traverse from the curr node to the root and update the timestamp @@ -308,7 +354,7 @@ impl Tree { .unwrap() .as_millis(); - if !tenant.eq("empty") { + if tenant != "empty" { let mut current_node = Some(curr); while let Some(node) = current_node { node.tenant_last_access_time @@ -317,21 +363,25 @@ impl Tree { } } - let ret_text = slice_by_chars(text, 0, curr_idx); + // Use indexed slice for result + let ret_text = indexed_text.slice_to_string(0, curr_idx); (ret_text, tenant) } #[allow(unused_assignments, dead_code)] pub fn prefix_match_tenant(&self, text: &str, tenant: &str) -> String { + // Pre-index text once for O(1) character access + let indexed_text = CharIndexedText::new(text); + let text_count = indexed_text.len(); + let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; let mut prev = Arc::clone(&self.root); - let text_count = text.chars().count(); while curr_idx < text_count { - let first_char = text.chars().nth(curr_idx).unwrap(); - let curr_text = slice_by_chars(text, curr_idx, text_count); + // O(1) character access instead of O(n) chars().nth() + let first_char = indexed_text.get(curr_idx).unwrap(); curr = prev.clone(); @@ -344,7 +394,9 @@ impl Tree { } let matched_text_guard = matched_node.text.read().unwrap(); - let shared_count = shared_prefix_count(&matched_text_guard, &curr_text); + // Use indexed comparison to avoid creating intermediate string + let shared_count = + shared_prefix_count_indexed(&indexed_text, curr_idx, &matched_text_guard); let 
matched_node_text_count = matched_text_guard.chars().count(); drop(matched_text_guard); @@ -381,7 +433,8 @@ impl Tree { } } - slice_by_chars(text, 0, curr_idx) + // Use indexed slice for result + indexed_text.slice_to_string(0, curr_idx) } fn leaf_of(node: &NodeRef) -> Vec { diff --git a/sgl-model-gateway/src/protocols/chat.rs b/sgl-model-gateway/src/protocols/chat.rs index c1f54828e378..33dbf668b9da 100644 --- a/sgl-model-gateway/src/protocols/chat.rs +++ b/sgl-model-gateway/src/protocols/chat.rs @@ -69,19 +69,24 @@ impl MessageContent { /// Returns the text content, cloning only when necessary. /// For simple text, returns a clone of the string. /// For parts, concatenates text parts with spaces. + /// Optimized to avoid intermediate Vec allocation. pub fn to_simple_string(&self) -> String { match self { MessageContent::Text(text) => text.clone(), MessageContent::Parts(parts) => { - // Pre-count text parts to avoid intermediate Vec allocation - let text_parts: Vec<&str> = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.as_str()), - _ => None, - }) - .collect(); - text_parts.join(" ") + // Append each text part to one String in a plain loop, without an intermediate Vec allocation + let mut result = String::new(); + let mut first = true; + for part in parts { + if let ContentPart::Text { text } = part { + if !first { + result.push(' '); + } + result.push_str(text); + first = false; + } + } + result } } } diff --git a/sgl-model-gateway/src/routers/http/pd_router.rs b/sgl-model-gateway/src/routers/http/pd_router.rs index 68926cf7afbf..dcaede54ebad 100644 --- a/sgl-model-gateway/src/routers/http/pd_router.rs +++ b/sgl-model-gateway/src/routers/http/pd_router.rs @@ -189,6 +189,11 @@ impl PDRouter { None } + // Shared names for the bootstrap JSON keys (callers still allocate via to_string()) + const BOOTSTRAP_HOST_KEY: &'static str = "bootstrap_host"; + const BOOTSTRAP_PORT_KEY: &'static str = "bootstrap_port"; + const BOOTSTRAP_ROOM_KEY: &'static str = "bootstrap_room"; + fn 
inject_bootstrap_into_value( mut original: Value, prefill_worker: &dyn Worker, @@ -207,12 +212,13 @@ impl PDRouter { ports.push(prefill_worker.bootstrap_port()); rooms.push(super::pd_types::generate_room_id()); } + // Keys come from the shared constants; each to_string() still allocates an owned key obj.insert( - "bootstrap_host".to_string(), + Self::BOOTSTRAP_HOST_KEY.to_string(), Value::Array(hosts.into_iter().map(Value::from).collect()), ); obj.insert( - "bootstrap_port".to_string(), + Self::BOOTSTRAP_PORT_KEY.to_string(), Value::Array( ports .into_iter() @@ -224,23 +230,24 @@ impl PDRouter { ), ); obj.insert( - "bootstrap_room".to_string(), + Self::BOOTSTRAP_ROOM_KEY.to_string(), Value::Array(rooms.into_iter().map(Value::from).collect()), ); } else { + // Keys come from the shared constants; each to_string() still allocates an owned key obj.insert( - "bootstrap_host".to_string(), + Self::BOOTSTRAP_HOST_KEY.to_string(), Value::from(prefill_worker.bootstrap_host()), ); obj.insert( - "bootstrap_port".to_string(), + Self::BOOTSTRAP_PORT_KEY.to_string(), match prefill_worker.bootstrap_port() { Some(v) => Value::from(v), None => Value::Null, }, ); obj.insert( - "bootstrap_room".to_string(), + Self::BOOTSTRAP_ROOM_KEY.to_string(), Value::from(super::pd_types::generate_room_id()), ); } @@ -256,12 +263,15 @@ impl PDRouter { let start_time = Instant::now(); let route = context.route; + // Clone request once outside the retry loop, then use Arc to share across attempts + // This avoids O(retries) clones by sharing the same data + let shared_request = Arc::new(original_request.clone()); RetryExecutor::execute_response_with_retry( &self.retry_config, { - let original_request = original_request.clone(); move |attempt: u32| { - let original_request = original_request.clone(); + // Clone Arc (cheap reference count increment) instead of cloning the entire request + let shared_request = Arc::clone(&shared_request); let context = context.clone(); async move { let (prefill, decode) = match self @@ -282,7 +292,7 @@ impl PDRouter { 
decode.url() ); - let mut json_request = match serde_json::to_value(&original_request) { + let mut json_request = match serde_json::to_value(shared_request.as_ref()) { Ok(v) => v, Err(e) => return Self::handle_serialization_error(e), }; @@ -899,6 +909,7 @@ impl PDRouter { } // Helper to merge logprobs from prefill and decode responses + // Optimized to avoid double cloning by taking ownership of decode array fn merge_logprobs_in_json(prefill_json: &Value, decode_json: &mut Value) -> bool { if let (Some(prefill_meta), Some(decode_meta)) = ( prefill_json.get("meta_info"), @@ -908,13 +919,17 @@ impl PDRouter { prefill_meta.get("input_token_logprobs"), decode_meta.get_mut("input_token_logprobs"), ) { - if let (Some(prefill_arr), Some(decode_arr)) = - (prefill_logprobs.as_array(), decode_logprobs.as_array_mut()) - { - let mut merged = prefill_arr.clone(); - merged.extend(decode_arr.clone()); - decode_meta["input_token_logprobs"] = Value::Array(merged); - return true; + if let Some(prefill_arr) = prefill_logprobs.as_array() { + // Take ownership of decode array to avoid cloning it. NOTE(review): mem::take leaves the Default value behind (presumably Null for serde_json::Value), so a non-array value here is dropped and not restored, unlike the removed code + let decode_arr = std::mem::take(decode_logprobs); + if let Value::Array(decode_vec) = decode_arr { + // Pre-allocate merged array with exact capacity + let mut merged = Vec::with_capacity(prefill_arr.len() + decode_vec.len()); + merged.extend(prefill_arr.iter().cloned()); + merged.extend(decode_vec); + decode_meta["input_token_logprobs"] = Value::Array(merged); + return true; + } } } } @@ -922,6 +937,7 @@ impl PDRouter { } // Simple helper to merge logprobs in streaming responses + // Optimized to reduce allocations in the merge path fn merge_streaming_logprobs( prefill_logprobs: Option, decode_chunk: &[u8], @@ -940,12 +956,16 @@ impl PDRouter { if let Some(ref p_logprobs) = prefill_logprobs { if let Some(meta) = decode_json.get_mut("meta_info") { if let Some(d_logprobs) = meta.get_mut("input_token_logprobs") { - if let (Some(p_arr), Some(d_arr)) = - (p_logprobs.as_array(), 
d_logprobs.as_array()) - { - let mut merged = p_arr.clone(); - merged.extend(d_arr.clone()); - *d_logprobs = Value::Array(merged); + if let Some(p_arr) = p_logprobs.as_array() { + // Take ownership of decode array to avoid cloning it. NOTE(review): mem::take leaves the Default value behind (presumably Null for serde_json::Value), so a non-array value here is dropped and not restored, unlike the removed code + let decode_arr = std::mem::take(d_logprobs); + if let Value::Array(d_vec) = decode_arr { + // Pre-allocate merged array with exact capacity + let mut merged = Vec::with_capacity(p_arr.len() + d_vec.len()); + merged.extend(p_arr.iter().cloned()); + merged.extend(d_vec); + *d_logprobs = Value::Array(merged); + } } } } diff --git a/sgl-model-gateway/src/routers/http/router.rs b/sgl-model-gateway/src/routers/http/router.rs index b7b6ee8c7696..7aad2c3f7229 100644 --- a/sgl-model-gateway/src/routers/http/router.rs +++ b/sgl-model-gateway/src/routers/http/router.rs @@ -334,8 +334,11 @@ impl Router { }; if let Some(api_key) = worker.api_key() { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", api_key)); + // Pre-allocate string with capacity to avoid reallocation + let mut auth_header = String::with_capacity(7 + api_key.len()); + auth_header.push_str("Bearer "); + auth_header.push_str(api_key); + request_builder = request_builder.header("Authorization", auth_header); } // Apply pre-filtered headers @@ -432,6 +435,9 @@ impl Router { let worker = self.worker_registry.get_by_url(worker_url); let api_key = worker.as_ref().and_then(|w| w.api_key().clone()); + // Named constant for the data-parallel rank JSON key + const DP_RANK_KEY: &str = "data_parallel_rank"; + let mut request_builder = if self.dp_aware { let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) { Ok(tup) => tup, @@ -457,10 +463,8 @@ impl Router { }; if let Some(map) = json_val.as_object_mut() { - map.insert( - String::from("data_parallel_rank"), - serde_json::json!(dp_rank), - ); + // Key comes from DP_RANK_KEY; to_string() still allocates an owned String for the map + map.insert(DP_RANK_KEY.to_string(), serde_json::json!(dp_rank)); // Only serialize if debug logging 
is enabled to avoid CPU overhead if tracing::enabled!(tracing::Level::DEBUG) { debug!( @@ -486,7 +490,11 @@ impl Router { }; if let Some(key) = api_key { - request_builder = request_builder.header("Authorization", format!("Bearer {}", key)); + // Pre-allocate string with capacity to avoid reallocation + let mut auth_header = String::with_capacity(7 + key.len()); + auth_header.push_str("Bearer "); + auth_header.push_str(&key); + request_builder = request_builder.header("Authorization", auth_header); } // Copy all headers from original request if provided