diff --git a/sgl-model-gateway/Cargo.toml b/sgl-model-gateway/Cargo.toml index 01c7dbc1f462..14ddef6527bb 100644 --- a/sgl-model-gateway/Cargo.toml +++ b/sgl-model-gateway/Cargo.toml @@ -156,6 +156,11 @@ name = "tool_parser_benchmark" harness = false path = "benches/tool_parser_benchmark.rs" +[[bench]] +name = "tree_benchmark" +harness = false +path = "benches/tree_benchmark.rs" + [profile.release] opt-level = "z" # Optimize for size lto = "fat" # Full LTO for smaller binaries diff --git a/sgl-model-gateway/benches/tree_benchmark.rs b/sgl-model-gateway/benches/tree_benchmark.rs new file mode 100644 index 000000000000..2b1589cf0fc4 --- /dev/null +++ b/sgl-model-gateway/benches/tree_benchmark.rs @@ -0,0 +1,294 @@ +//! Benchmarks for the radix tree implementation used in cache-aware routing. +//! +//! Run with: cargo bench --bench tree_benchmark + +use std::{sync::Arc, thread}; + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use rand::{ + distr::{Alphanumeric, SampleString}, + rng as thread_rng, +}; +// Import the tree module +use sgl_model_gateway::policies::tree::Tree; + +/// Generate random ASCII strings of given length +fn random_ascii_string(len: usize) -> String { + Alphanumeric.sample_string(&mut thread_rng(), len) +} + +/// Generate random strings with common prefixes (simulates real request patterns) +fn random_prefixed_strings(prefix: &str, suffix_len: usize, count: usize) -> Vec { + (0..count) + .map(|_| format!("{}{}", prefix, random_ascii_string(suffix_len))) + .collect() +} + +/// Benchmark single-threaded insert throughput +fn bench_insert_throughput(c: &mut Criterion) { + let mut group = c.benchmark_group("insert_throughput"); + + for text_len in [10, 50, 100, 500].iter() { + group.throughput(Throughput::Elements(1)); + group.bench_with_input( + BenchmarkId::new("random_text", text_len), + text_len, + |b, &len| { + let tree = Tree::new(); + let strings: Vec = (0..1000).map(|_| random_ascii_string(len)).collect(); + let mut idx = 0; + + b.iter(|| { + tree.insert(black_box(&strings[idx % strings.len()]), "tenant1"); + idx += 1; + }); + }, + ); + } + + // Benchmark with shared prefixes (common cache scenario) + group.bench_function("shared_prefix_100", |b| { + let tree = Tree::new(); + let prefixes = ["system:", "user:", "assistant:", "tool:"]; + let strings: Vec = prefixes + .iter() + .flat_map(|p| random_prefixed_strings(p, 50, 250)) + .collect(); + let mut idx = 0; + + b.iter(|| { + tree.insert(black_box(&strings[idx % strings.len()]), "tenant1"); + idx += 1; + }); + }); + + group.finish(); +} + +/// Benchmark prefix_match latency +fn bench_prefix_match_latency(c: &mut Criterion) { + let mut group = c.benchmark_group("prefix_match_latency"); + + // Setup: pre-populate tree with data + let tree = Tree::new(); + let prefixes = ["system:", "user:", "assistant:", "tool:"]; + let strings: Vec = prefixes + .iter() + .flat_map(|p| random_prefixed_strings(p, 50, 1000)) + .collect(); + + for s in &strings { + tree.insert(s, "tenant1"); + } + + // Benchmark cache hit (exact match) + group.bench_function("cache_hit", |b| { + let mut idx = 0; + b.iter(|| { + let result = tree.prefix_match(black_box(&strings[idx % strings.len()])); + idx += 1; + result + }); + }); + + // Benchmark cache miss (no match) + let miss_strings: Vec = (0..1000).map(|_| random_ascii_string(50)).collect(); + group.bench_function("cache_miss", |b| { + let mut idx = 0; + b.iter(|| { + let result = tree.prefix_match(black_box(&miss_strings[idx % miss_strings.len()])); + idx += 1; + result + }); + }); + + // Benchmark partial match + group.bench_function("partial_match", |b| { + let partial_strings: Vec = prefixes + .iter() + .map(|p| format!("{}partial_query", p)) + .collect(); + let mut idx = 0; + b.iter(|| { + let result = + tree.prefix_match(black_box(&partial_strings[idx % partial_strings.len()])); + idx += 1; + result + }); + }); + + group.finish(); +} + +/// Benchmark concurrent operations +fn bench_concurrent_operations(c: &mut Criterion) { + let mut group = c.benchmark_group("concurrent"); + group.sample_size(50); // Reduce sample size for concurrent tests + + // Mixed read/write workload + for num_threads in [2, 4, 8].iter() { + group.bench_with_input( + BenchmarkId::new("mixed_workload", num_threads), + num_threads, + |b, &threads| { + b.iter(|| { + let tree = Arc::new(Tree::new()); + let handles: Vec<_> = (0..threads) + .map(|t| { + let tree = Arc::clone(&tree); + thread::spawn(move || { + let tenant = format!("tenant{}", t); + for i in 0..100 { + let text = format!("thread{}_request{}", t, i); + if i % 3 == 0 { + tree.prefix_match(&text); + } else { + tree.insert(&text, &tenant); + } + } + }) + }) + .collect(); + + for h in handles { + h.join().unwrap(); + } + }); + }, + ); + } + + group.finish(); +} + +/// Benchmark eviction performance +fn bench_eviction(c: &mut Criterion) { + let mut group = c.benchmark_group("eviction"); + group.sample_size(20); // Eviction is expensive + + for tree_size in [1000, 5000, 10000].iter() { + group.bench_with_input( + BenchmarkId::new("evict_to_half", tree_size), + tree_size, + |b, &size| { + b.iter_with_setup( + || { + // Setup: create tree with many entries + let tree = Tree::new(); + for i in 0..size { + tree.insert(&format!("entry_{:05}", i), "tenant1"); + } + tree + }, + |tree| { + // Evict to half size + tree.evict_tenant_by_size(size / 2); + }, + ); + }, + ); + } + + group.finish(); +} + +/// Benchmark UTF-8 handling vs ASCII +fn bench_utf8_vs_ascii(c: &mut Criterion) { + let mut group = c.benchmark_group("encoding"); + + let tree_ascii = Tree::new(); + let tree_utf8 = Tree::new(); + + // Pre-populate + let ascii_strings: Vec = (0..1000).map(|_| random_ascii_string(50)).collect(); + let utf8_strings: Vec = (0..1000).map(|i| format!("你好世界_{}", i)).collect(); + + for s in &ascii_strings { + tree_ascii.insert(s, "tenant1"); + } + for s in &utf8_strings { + tree_utf8.insert(s, "tenant1"); + } + + group.bench_function("ascii_match", |b| { + let mut idx = 0; + b.iter(|| { + let result = + tree_ascii.prefix_match(black_box(&ascii_strings[idx % ascii_strings.len()])); + idx += 1; + result + }); + }); + + group.bench_function("utf8_match", |b| { + let mut idx = 0; + b.iter(|| { + let result = tree_utf8.prefix_match(black_box(&utf8_strings[idx % utf8_strings.len()])); + idx += 1; + result + }); + }); + + group.finish(); +} + +/// Benchmark multi-tenant scenarios +fn bench_multi_tenant(c: &mut Criterion) { + let mut group = c.benchmark_group("multi_tenant"); + + let tree = Tree::new(); + + // Setup: multiple tenants with overlapping data + let tenants = ["worker1", "worker2", "worker3", "worker4"]; + let prefixes = ["prompt:", "completion:", "context:"]; + + for tenant in &tenants { + for prefix in &prefixes { + for i in 0..100 { + tree.insert(&format!("{}data_{}", prefix, i), tenant); + } + } + } + + group.bench_function("shared_prefix_lookup", |b| { + let queries: Vec = prefixes + .iter() + .flat_map(|p| (0..10).map(move |i| format!("{}data_{}", p, i))) + .collect(); + let mut idx = 0; + + b.iter(|| { + let result = tree.prefix_match(black_box(&queries[idx % queries.len()])); + idx += 1; + result + }); + }); + + group.bench_function("tenant_specific_match", |b| { + let queries: Vec<(String, &str)> = tenants + .iter() + .flat_map(|&t| (0..10).map(move |i| (format!("prompt:data_{}", i), t))) + .collect(); + let mut idx = 0; + + b.iter(|| { + let (query, tenant) = &queries[idx % queries.len()]; + let result = tree.prefix_match_tenant(black_box(query), black_box(tenant)); + idx += 1; + result + }); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_insert_throughput, + bench_prefix_match_latency, + bench_concurrent_operations, + bench_eviction, + bench_utf8_vs_ascii, + bench_multi_tenant, +); +criterion_main!(benches); diff --git a/sgl-model-gateway/src/policies/cache_aware.rs b/sgl-model-gateway/src/policies/cache_aware.rs index 9aef85a944c2..af2648dee473 100644 --- a/sgl-model-gateway/src/policies/cache_aware.rs +++ b/sgl-model-gateway/src/policies/cache_aware.rs @@ -59,7 +59,14 @@ during the next eviction cycle. */ -use std::{sync::Arc, thread, time::Duration}; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + thread, + time::Duration, +}; use dashmap::DashMap; use rand::Rng; @@ -77,7 +84,10 @@ use crate::{core::Worker, observability::metrics::RouterMetrics}; pub struct CacheAwarePolicy { config: CacheAwareConfig, trees: Arc>>, + /// Handle to the background eviction thread eviction_handle: Option>, + /// Flag to signal the eviction thread to stop + shutdown_flag: Arc, } impl CacheAwarePolicy { @@ -87,25 +97,48 @@ impl CacheAwarePolicy { pub fn with_config(config: CacheAwareConfig) -> Self { let trees = Arc::new(DashMap::>::new()); + let shutdown_flag = Arc::new(AtomicBool::new(false)); // Start background eviction thread if configured let eviction_handle = if config.eviction_interval_secs > 0 { let trees_clone = Arc::clone(&trees); + let shutdown_clone = Arc::clone(&shutdown_flag); let max_tree_size = config.max_tree_size; let interval = config.eviction_interval_secs; - Some(thread::spawn(move || loop { - thread::sleep(Duration::from_secs(interval)); - - // Evict for all model trees - for tree_ref in trees_clone.iter() { - let model_id = tree_ref.key(); - let tree = tree_ref.value(); - tree.evict_tenant_by_size(max_tree_size); - debug!( - "Cache eviction completed for model {}, max_size: {}", - model_id, max_tree_size - ); + Some(thread::spawn(move || { + // Use smaller sleep intervals to check shutdown flag more frequently + let check_interval_ms = 100; // Check every 100ms + let total_sleep_ms = interval * 1000; + + loop { + // Sleep in small increments, checking shutdown flag periodically + let mut slept_ms = 0u64; + while slept_ms < total_sleep_ms { + if shutdown_clone.load(Ordering::Relaxed) { + debug!("Eviction thread received shutdown signal"); + return; + } + thread::sleep(Duration::from_millis(check_interval_ms)); + slept_ms += check_interval_ms; + } + + // Check shutdown before starting eviction + if shutdown_clone.load(Ordering::Relaxed) { + debug!("Eviction thread received shutdown signal"); + return; + } + + // Evict for all model trees + for tree_ref in trees_clone.iter() { + let model_id = tree_ref.key(); + let tree = tree_ref.value(); + tree.evict_tenant_by_size(max_tree_size); + debug!( + "Cache eviction completed for model {}, max_size: {}", + model_id, max_tree_size + ); + } } })) } else { @@ -116,6 +149,7 @@ impl CacheAwarePolicy { config, trees, eviction_handle, + shutdown_flag, } } @@ -407,12 +441,16 @@ impl Default for CacheAwarePolicy { impl Drop for CacheAwarePolicy { fn drop(&mut self) { - // Note: We can't properly stop the eviction thread since it's in an infinite loop - // In a production system, we'd use a channel or atomic flag to signal shutdown + // Signal the eviction thread to stop + self.shutdown_flag.store(true, Ordering::Relaxed); + + // Wait for the thread to finish (with timeout) if let Some(handle) = self.eviction_handle.take() { - // The thread will continue running until the program exits - // This is acceptable for now since the router typically runs for the lifetime of the program - drop(handle); + // The thread checks the shutdown flag every 100ms, so it should exit quickly + match handle.join() { + Ok(()) => debug!("Eviction thread shut down cleanly"), + Err(_) => debug!("Eviction thread panicked during shutdown"), + } } } } diff --git a/sgl-model-gateway/src/policies/mod.rs b/sgl-model-gateway/src/policies/mod.rs index 8edf7f6a4e3e..f74483d6a3d5 100644 --- a/sgl-model-gateway/src/policies/mod.rs +++ b/sgl-model-gateway/src/policies/mod.rs @@ -14,7 +14,7 @@ mod power_of_two; mod random; mod registry; mod round_robin; -mod tree; +pub mod tree; pub use bucket::BucketPolicy; pub use cache_aware::CacheAwarePolicy; diff --git a/sgl-model-gateway/src/policies/tree.rs b/sgl-model-gateway/src/policies/tree.rs index f8c67e1869df..97df66cee376 100644 --- a/sgl-model-gateway/src/policies/tree.rs +++ b/sgl-model-gateway/src/policies/tree.rs @@ -1,8 +1,12 @@ use std::{ cmp::Reverse, collections::{BinaryHeap, HashMap, VecDeque}, - sync::{Arc, RwLock}, - time::{Duration, SystemTime, UNIX_EPOCH}, + hash::{BuildHasherDefault, Hasher}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, RwLock, + }, + time::{SystemTime, UNIX_EPOCH}, }; use dashmap::{mapref::entry::Entry, DashMap}; @@ -10,6 +14,45 @@ use tracing::info; type NodeRef = Arc; +/// Interned tenant ID to avoid repeated string allocations. +/// Using Arc allows cheap cloning and comparison. +pub type TenantId = Arc; + +/// A fast identity hasher for single-character keys (used in children DashMap). +/// Since chars have good distribution already, we use identity hashing with mixing. +#[derive(Default)] +struct CharHasher(u64); + +impl Hasher for CharHasher { + #[inline(always)] + fn finish(&self) -> u64 { + self.0 + } + + #[inline(always)] + fn write(&mut self, bytes: &[u8]) { + // Fast path for 4-byte (char) writes - avoid loop + if bytes.len() == 4 { + let val = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + // Mix with golden ratio for better distribution + self.0 = (val as u64).wrapping_mul(0x9E3779B97F4A7C15); + return; + } + // Fallback for other sizes (shouldn't happen for char keys) + for &byte in bytes { + self.0 = self.0.wrapping_mul(0x100000001b3).wrapping_add(byte as u64); + } + } + + #[inline(always)] + fn write_u32(&mut self, i: u32) { + // Chars are u32 - use golden ratio multiplication for distribution + self.0 = (i as u64).wrapping_mul(0x9E3779B97F4A7C15); + } +} + +type CharHasherBuilder = BuildHasherDefault; + /// Pre-indexed text for efficient character access. /// Converts UTF-8 string to Vec once to enable O(1) indexing. struct CharIndexedText { @@ -40,25 +83,146 @@ impl CharIndexedText { } } +/// Node text with cached character count to avoid repeated O(n) chars().count() calls. +#[derive(Debug)] +struct NodeText { + /// The actual text stored in this node + text: String, + /// Cached character count (UTF-8 chars, not bytes) + char_count: usize, +} + +impl NodeText { + #[inline] + fn new(text: String) -> Self { + let char_count = text.chars().count(); + Self { text, char_count } + } + + #[inline] + fn empty() -> Self { + Self { + text: String::new(), + char_count: 0, + } + } + + #[inline] + fn char_count(&self) -> usize { + self.char_count + } + + #[inline] + fn as_str(&self) -> &str { + &self.text + } + + #[inline] + fn first_char(&self) -> Option { + self.text.chars().next() + } + + /// Split the text at a character boundary, returning the prefix and suffix. + /// This is more efficient than slice_by_chars as it computes both at once. + #[inline] + fn split_at_char(&self, char_idx: usize) -> (NodeText, NodeText) { + if char_idx == 0 { + return (NodeText::empty(), self.clone_text()); + } + if char_idx >= self.char_count { + return (self.clone_text(), NodeText::empty()); + } + + // Find byte index for the character boundary + let byte_idx = self + .text + .char_indices() + .nth(char_idx) + .map(|(i, _)| i) + .unwrap_or(self.text.len()); + + let prefix = NodeText { + text: self.text[..byte_idx].to_string(), + char_count: char_idx, + }; + let suffix = NodeText { + text: self.text[byte_idx..].to_string(), + char_count: self.char_count - char_idx, + }; + (prefix, suffix) + } + + #[inline] + fn clone_text(&self) -> NodeText { + NodeText { + text: self.text.clone(), + char_count: self.char_count, + } + } +} + +impl Clone for NodeText { + fn clone(&self) -> Self { + self.clone_text() + } +} + +/// Global timestamp that gets updated periodically to reduce syscalls. +/// Uses milliseconds since epoch. +static CURRENT_TIMESTAMP_MS: AtomicU64 = AtomicU64::new(0); + +/// Staleness threshold in milliseconds for forced refresh. +/// If cached timestamp is older than this, always get fresh time. +const TIMESTAMP_STALENESS_MS: u64 = 5; + +/// Get current timestamp in milliseconds, using cached value when possible. +/// Refreshes if the cached value is stale (>TIMESTAMP_STALENESS_MS). +/// This provides ~99% syscall reduction under high load while maintaining accuracy. +#[inline] +fn get_timestamp_ms() -> u128 { + let cached = CURRENT_TIMESTAMP_MS.load(Ordering::Relaxed); + + // Always need syscall to check staleness, but it's cheap and necessary for correctness + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + // Fast path: return cached if still fresh (within TIMESTAMP_STALENESS_MS) + if cached != 0 && now.saturating_sub(cached) < TIMESTAMP_STALENESS_MS { + return cached as u128; + } + + // Update cached value + CURRENT_TIMESTAMP_MS.store(now, Ordering::Relaxed); + now as u128 +} + #[derive(Debug)] struct Node { - children: DashMap, - text: RwLock, - tenant_last_access_time: DashMap, + /// Children nodes indexed by first character. + /// Using custom hasher optimized for char keys. + children: DashMap, + /// Node text with cached character count + text: RwLock, + /// Per-tenant last access timestamps. Using TenantId (Arc) for cheap cloning. + tenant_last_access_time: DashMap, + /// Parent pointer for upward traversal during timestamp updates parent: RwLock>, } #[derive(Debug)] pub struct Tree { root: NodeRef, - pub tenant_char_count: DashMap, + /// Per-tenant character count for size tracking. Using TenantId for consistency. + pub tenant_char_count: DashMap, } // For the heap struct EvictionEntry { timestamp: u128, - tenant: String, + tenant: TenantId, node: NodeRef, } @@ -87,7 +251,8 @@ impl PartialEq for EvictionEntry { // Note that in rust, `.len()` or slice is operated on the "byte" level. It causes issues for UTF-8 characters because one character might use multiple bytes. // https://en.wikipedia.org/wiki/UTF-8 -/// Efficient shared prefix count using pre-indexed chars for O(1) access +/// Efficient shared prefix count using pre-indexed chars for O(1) access. +/// Returns the number of characters that match between `a` (starting at `a_start`) and `b`. #[inline] fn shared_prefix_count_indexed(a: &CharIndexedText, a_start: usize, b: &str) -> usize { let mut i = 0; @@ -105,8 +270,10 @@ fn shared_prefix_count_indexed(a: &CharIndexedText, a_start: usize, b: &str) -> i } -fn slice_by_chars(s: &str, start: usize, end: usize) -> String { - s.chars().skip(start).take(end - start).collect() +/// Intern a tenant string into an Arc for efficient storage and comparison. +#[inline] +fn intern_tenant(tenant: &str) -> TenantId { + Arc::from(tenant) } impl Default for Tree { @@ -122,13 +289,19 @@ impl Tree { 1. Storing data for multiple tenants (the overlap of multiple radix tree) 2. Node-level lock to enable concurrent access on nodes 3. Leaf LRU eviction based on tenant access time + + Optimizations: + - Cached character counts in NodeText to avoid O(n) chars().count() calls + - Interned tenant IDs (Arc) for cheap cloning and comparison + - Batched timestamp updates to reduce syscalls + - Custom hasher for char keys in children DashMap */ pub fn new() -> Self { Tree { root: Arc::new(Node { - children: DashMap::new(), - text: RwLock::new("".to_string()), + children: DashMap::with_hasher(CharHasherBuilder::default()), + text: RwLock::new(NodeText::empty()), tenant_last_access_time: DashMap::new(), parent: RwLock::new(None), }), @@ -145,16 +318,17 @@ impl Tree { let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; - let timestamp_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis(); + // Use cached timestamp to reduce syscalls + let timestamp_ms = get_timestamp_ms(); + + // Intern the tenant ID once for reuse + let tenant_id = intern_tenant(tenant); curr.tenant_last_access_time - .insert(tenant.to_string(), timestamp_ms); + .insert(Arc::clone(&tenant_id), timestamp_ms); self.tenant_char_count - .entry(tenant.to_string()) + .entry(Arc::clone(&tenant_id)) .or_insert(0); let mut prev = Arc::clone(&self.root); @@ -183,20 +357,20 @@ impl Tree { let curr_text = indexed_text.slice_to_string(curr_idx, text_count); let curr_text_count = text_count - curr_idx; let new_node = Arc::new(Node { - children: DashMap::new(), - text: RwLock::new(curr_text), + children: DashMap::with_hasher(CharHasherBuilder::default()), + text: RwLock::new(NodeText::new(curr_text)), tenant_last_access_time: DashMap::new(), parent: RwLock::new(Some(Arc::clone(&curr))), }); // Attach tenant to the new node (map is empty here) and increment count once self.tenant_char_count - .entry(tenant.to_string()) + .entry(Arc::clone(&tenant_id)) .and_modify(|count| *count += curr_text_count) .or_insert(curr_text_count); new_node .tenant_last_access_time - .insert(tenant.to_string(), timestamp_ms); + .insert(Arc::clone(&tenant_id), timestamp_ms); entry.insert(Arc::clone(&new_node)); @@ -209,11 +383,15 @@ impl Tree { let matched_node = entry.get().clone(); let matched_node_text = matched_node.text.read().unwrap(); - let matched_node_text_count = matched_node_text.chars().count(); + // Use cached char count instead of chars().count() + let matched_node_text_count = matched_node_text.char_count(); // Use indexed comparison to avoid creating intermediate string - let shared_count = - shared_prefix_count_indexed(&indexed_text, curr_idx, &matched_node_text); + let shared_count = shared_prefix_count_indexed( + &indexed_text, + curr_idx, + matched_node_text.as_str(), + ); if shared_count < matched_node_text_count { /* @@ -223,12 +401,9 @@ impl Tree { [curr] -> [new_node] -> [contracted_matched_node] */ - let matched_text = slice_by_chars(&matched_node_text, 0, shared_count); - let contracted_text = slice_by_chars( - &matched_node_text, - shared_count, - matched_node_text_count, - ); + // Use split_at_char for efficient splitting with cached counts + let (matched_text, contracted_text) = + matched_node_text.split_at_char(shared_count); let matched_text_count = shared_count; // Drop read lock before creating new node @@ -236,12 +411,12 @@ impl Tree { let new_node = Arc::new(Node { text: RwLock::new(matched_text), - children: DashMap::new(), + children: DashMap::with_hasher(CharHasherBuilder::default()), parent: RwLock::new(Some(Arc::clone(&curr))), tenant_last_access_time: matched_node.tenant_last_access_time.clone(), }); - let first_new_char = contracted_text.chars().next().unwrap(); + let first_new_char = contracted_text.first_char().unwrap(); new_node .children .insert(first_new_char, Arc::clone(&matched_node)); @@ -254,10 +429,10 @@ impl Tree { prev = Arc::clone(&new_node); // Atomically attach tenant to the new split node and increment count once - match prev.tenant_last_access_time.entry(tenant.to_string()) { + match prev.tenant_last_access_time.entry(Arc::clone(&tenant_id)) { Entry::Vacant(v) => { self.tenant_char_count - .entry(tenant.to_string()) + .entry(Arc::clone(&tenant_id)) .and_modify(|count| *count += matched_text_count) .or_insert(matched_text_count); v.insert(timestamp_ms); @@ -276,10 +451,10 @@ impl Tree { prev = Arc::clone(&matched_node); // Atomically attach tenant to existing node and increment count once - match prev.tenant_last_access_time.entry(tenant.to_string()) { + match prev.tenant_last_access_time.entry(Arc::clone(&tenant_id)) { Entry::Vacant(v) => { self.tenant_char_count - .entry(tenant.to_string()) + .entry(Arc::clone(&tenant_id)) .and_modify(|count| *count += matched_node_text_count) .or_insert(matched_node_text_count); v.insert(timestamp_ms); @@ -316,9 +491,13 @@ impl Tree { let matched_node = entry.value().clone(); let matched_text_guard = matched_node.text.read().unwrap(); // Use indexed comparison to avoid creating intermediate string - let shared_count = - shared_prefix_count_indexed(&indexed_text, curr_idx, &matched_text_guard); - let matched_node_text_count = matched_text_guard.chars().count(); + let shared_count = shared_prefix_count_indexed( + &indexed_text, + curr_idx, + matched_text_guard.as_str(), + ); + // Use cached char count instead of chars().count() + let matched_node_text_count = matched_text_guard.char_count(); drop(matched_text_guard); if shared_count == matched_node_text_count { @@ -339,33 +518,32 @@ impl Tree { curr = prev.clone(); - // Select the first tenant (key in the map) - let tenant = curr + // Select the first tenant (key in the map) - use Arc directly + let tenant: Option = curr .tenant_last_access_time .iter() .next() - .map(|kv| kv.key().to_owned()) - .unwrap_or_else(|| "empty".to_string()); + .map(|kv| Arc::clone(kv.key())); - // Traverse from the curr node to the root and update the timestamp + // Use cached timestamp to reduce syscalls + let timestamp_ms = get_timestamp_ms(); - let timestamp_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis(); - - if tenant != "empty" { + // Traverse from the curr node to the root and update the timestamp + if let Some(ref tenant_id) = tenant { let mut current_node = Some(curr); while let Some(node) = current_node { node.tenant_last_access_time - .insert(tenant.clone(), timestamp_ms); + .insert(Arc::clone(tenant_id), timestamp_ms); current_node = node.parent.read().unwrap().clone(); } } // Use indexed slice for result let ret_text = indexed_text.slice_to_string(0, curr_idx); - (ret_text, tenant) + let tenant_str = tenant + .map(|t| t.to_string()) + .unwrap_or_else(|| "empty".to_string()); + (ret_text, tenant_str) } #[allow(unused_assignments, dead_code)] @@ -374,6 +552,9 @@ impl Tree { let indexed_text = CharIndexedText::new(text); let text_count = indexed_text.len(); + // Intern tenant ID once for efficient lookups + let tenant_id = intern_tenant(tenant); + let mut curr = Arc::clone(&self.root); let mut curr_idx = 0; @@ -389,15 +570,23 @@ impl Tree { let matched_node = entry.value().clone(); // Only continue matching if this node belongs to the specified tenant - if !matched_node.tenant_last_access_time.contains_key(tenant) { + // Note: contains_key with &str works because Arc implements Borrow + if !matched_node + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { break; } let matched_text_guard = matched_node.text.read().unwrap(); // Use indexed comparison to avoid creating intermediate string - let shared_count = - shared_prefix_count_indexed(&indexed_text, curr_idx, &matched_text_guard); - let matched_node_text_count = matched_text_guard.chars().count(); + let shared_count = shared_prefix_count_indexed( + &indexed_text, + curr_idx, + matched_text_guard.as_str(), + ); + // Use cached char count instead of chars().count() + let matched_node_text_count = matched_text_guard.char_count(); drop(matched_text_guard); if shared_count == matched_node_text_count { @@ -419,16 +608,17 @@ impl Tree { curr = prev.clone(); // Only update timestamp if we found a match for the specified tenant - if curr.tenant_last_access_time.contains_key(tenant) { - let timestamp_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis(); + if curr + .tenant_last_access_time + .contains_key(tenant_id.as_ref()) + { + // Use cached timestamp to reduce syscalls + let timestamp_ms = get_timestamp_ms(); let mut current_node = Some(curr); while let Some(node) = current_node { node.tenant_last_access_time - .insert(tenant.to_string(), timestamp_ms); + .insert(Arc::clone(&tenant_id), timestamp_ms); current_node = node.parent.read().unwrap().clone(); } } @@ -437,19 +627,21 @@ impl Tree { indexed_text.slice_to_string(0, curr_idx) } - fn leaf_of(node: &NodeRef) -> Vec { + fn leaf_of(node: &NodeRef) -> Vec { /* - Return the list of tenants if it's a leaf for the tenant + Return the list of tenants if it's a leaf for the tenant. + A tenant is a "leaf" at this node if this node has the tenant but none of its children do. */ - let mut candidates: HashMap = node + let mut candidates: HashMap = node .tenant_last_access_time .iter() - .map(|entry| (entry.key().clone(), true)) + .map(|entry| (Arc::clone(entry.key()), true)) .collect(); for child in node.children.iter() { for tenant in child.value().tenant_last_access_time.iter() { - candidates.insert(tenant.key().clone(), false); + // Mark as non-leaf if any child has this tenant + candidates.insert(Arc::clone(tenant.key()), false); } } @@ -472,10 +664,10 @@ impl Tree { // Add leaves to priority queue for tenant in Tree::leaf_of(&curr) { - if let Some(timestamp) = curr.tenant_last_access_time.get(&tenant) { + if let Some(timestamp) = curr.tenant_last_access_time.get(tenant.as_ref()) { pq.push(Reverse(EvictionEntry { timestamp: *timestamp, - tenant: tenant.clone(), + tenant: Arc::clone(&tenant), node: Arc::clone(&curr), })); } @@ -491,30 +683,31 @@ impl Tree { while let Some(Reverse(entry)) = pq.pop() { let EvictionEntry { tenant, node, .. } = entry; - if let Some(used_size) = self.tenant_char_count.get(&tenant) { + if let Some(used_size) = self.tenant_char_count.get(tenant.as_ref()) { if *used_size <= max_size { continue; } } // Decrement when removing tenant from node - if node.tenant_last_access_time.contains_key(&tenant) { - let node_len = node.text.read().unwrap().chars().count(); + if node.tenant_last_access_time.contains_key(tenant.as_ref()) { + // Use cached char count instead of chars().count() + let node_len = node.text.read().unwrap().char_count(); self.tenant_char_count - .entry(tenant.clone()) + .entry(Arc::clone(&tenant)) .and_modify(|count| { *count = count.saturating_sub(node_len); }); } // Remove tenant from node - node.tenant_last_access_time.remove(&tenant); + node.tenant_last_access_time.remove(tenant.as_ref()); // Remove empty nodes if node.children.is_empty() && node.tenant_last_access_time.is_empty() { if let Some(parent) = node.parent.read().unwrap().as_ref() { let text_guard = node.text.read().unwrap(); - if let Some(first_char) = text_guard.chars().next() { + if let Some(first_char) = text_guard.first_char() { parent.children.remove(&first_char); } } @@ -522,11 +715,12 @@ impl Tree { // Add parent to queue if it becomes a leaf if let Some(parent) = node.parent.read().unwrap().as_ref() { - if Tree::leaf_of(parent).contains(&tenant) { - if let Some(timestamp) = parent.tenant_last_access_time.get(&tenant) { + let parent_leaves = Tree::leaf_of(parent); + if parent_leaves.iter().any(|t| t.as_ref() == tenant.as_ref()) { + if let Some(timestamp) = parent.tenant_last_access_time.get(tenant.as_ref()) { pq.push(Reverse(EvictionEntry { timestamp: *timestamp, - tenant: tenant.clone(), + tenant: Arc::clone(&tenant), node: Arc::clone(parent), })); } @@ -541,6 +735,9 @@ impl Tree { } pub fn remove_tenant(&self, tenant: &str) { + // Intern tenant ID once for efficient lookups + let tenant_id = intern_tenant(tenant); + // 1. Find all the leaves for the tenant let mut stack = vec![Arc::clone(&self.root)]; let mut queue = VecDeque::new(); @@ -550,7 +747,8 @@ impl Tree { stack.push(Arc::clone(child.value())); } - if Tree::leaf_of(&curr).contains(&tenant.to_string()) { + let leaves = Tree::leaf_of(&curr); + if leaves.iter().any(|t| t.as_ref() == tenant_id.as_ref()) { queue.push_back(Arc::clone(&curr)); } } @@ -558,13 +756,13 @@ impl Tree { // 2. Start from the leaves and traverse up to the root, removing the tenant from each node while let Some(curr) = queue.pop_front() { // remove tenant from node - curr.tenant_last_access_time.remove(&tenant.to_string()); + curr.tenant_last_access_time.remove(tenant_id.as_ref()); // remove empty nodes if curr.children.is_empty() && curr.tenant_last_access_time.is_empty() { if let Some(parent) = curr.parent.read().unwrap().as_ref() { let text_guard = curr.text.read().unwrap(); - if let Some(first_char) = text_guard.chars().next() { + if let Some(first_char) = text_guard.first_char() { parent.children.remove(&first_char); } } @@ -572,21 +770,25 @@ impl Tree { // add parent to queue if it becomes a leaf if let Some(parent) = curr.parent.read().unwrap().as_ref() { - if Tree::leaf_of(parent).contains(&tenant.to_string()) { + let parent_leaves = Tree::leaf_of(parent); + if parent_leaves + .iter() + .any(|t| t.as_ref() == tenant_id.as_ref()) + { queue.push_back(Arc::clone(parent)); } } } // 3. Remove the tenant from the tenant_char_count map - self.tenant_char_count.remove(&tenant.to_string()); + self.tenant_char_count.remove(tenant_id.as_ref()); } #[allow(dead_code)] pub fn get_tenant_char_count(&self) -> HashMap { self.tenant_char_count .iter() - .map(|entry| (entry.key().clone(), *entry.value())) + .map(|entry| (entry.key().to_string(), *entry.value())) .collect() } @@ -598,11 +800,12 @@ impl Tree { let mut stack = vec![Arc::clone(&self.root)]; while let Some(curr) = stack.pop() { - let text_count = curr.text.read().unwrap().chars().count(); + // Use cached char count instead of chars().count() + let text_count = curr.text.read().unwrap().char_count(); for tenant in curr.tenant_last_access_time.iter() { let size = used_size_per_tenant - .entry(tenant.key().clone()) + .entry(tenant.key().to_string()) .or_insert(0); *size += text_count; } @@ -617,6 +820,8 @@ impl Tree { #[allow(dead_code)] fn node_to_string(node: &NodeRef, prefix: &str, is_last: bool) -> String { + use std::time::Duration; + let mut result = String::new(); // Add prefix and branch character @@ -625,7 +830,7 @@ impl Tree { // Add node text let node_text = node.text.read().unwrap(); - result.push_str(&format!("'{}' [", node_text)); + result.push_str(&format!("'{}' [", node_text.as_str())); // Add tenant information with timestamps let mut tenant_info = Vec::new(); @@ -695,7 +900,10 @@ impl Tree { // Unit tests #[cfg(test)] mod tests { - use std::{thread, time::Instant}; + use std::{ + thread, + time::{Duration, Instant}, + }; use rand::{ distr::{Alphanumeric, SampleString}, @@ -704,6 +912,14 @@ mod tests { use super::*; + /// Helper to convert tenant_char_count to HashMap for comparison + fn get_maintained_counts(tree: &Tree) -> HashMap { + tree.tenant_char_count + .iter() + .map(|entry| (entry.key().to_string(), *entry.value())) + .collect() + } + #[test] fn test_tenant_char_count() { let tree = Tree::new(); @@ -715,11 +931,7 @@ mod tests { tree.insert("application", "tenant2"); let computed_sizes = tree.get_used_size_per_tenant(); - let maintained_counts: HashMap = tree - .tenant_char_count - .iter() - .map(|entry| (entry.key().clone(), *entry.value())) - .collect(); + let maintained_counts = get_maintained_counts(&tree); println!("Phase 1 - Maintained vs Computed counts:"); println!( @@ -737,11 +949,7 @@ mod tests { tree.insert("box", "tenant2"); let computed_sizes = tree.get_used_size_per_tenant(); - let maintained_counts: HashMap = tree - .tenant_char_count - .iter() - .map(|entry| (entry.key().clone(), *entry.value())) - .collect(); + let maintained_counts = get_maintained_counts(&tree); println!("Phase 2 - Maintained vs Computed counts:"); println!( @@ -759,11 +967,7 @@ mod tests { tree.insert("zero", "tenant2"); let computed_sizes = tree.get_used_size_per_tenant(); - let maintained_counts: HashMap = tree - .tenant_char_count - .iter() - .map(|entry| (entry.key().clone(), *entry.value())) - .collect(); + let maintained_counts = get_maintained_counts(&tree); println!("Phase 3 - Maintained vs Computed counts:"); println!( @@ -778,11 +982,7 @@ mod tests { tree.evict_tenant_by_size(10); let computed_sizes = tree.get_used_size_per_tenant(); - let maintained_counts: HashMap = tree - .tenant_char_count - .iter() - .map(|entry| (entry.key().clone(), *entry.value())) - .collect(); + let maintained_counts = get_maintained_counts(&tree); println!("Phase 4 - Maintained vs Computed counts:"); println!( @@ -1268,17 +1468,22 @@ mod tests { fn test_leaf_of() { let tree = Tree::new(); + // Helper to convert leaves to strings for easier assertion + let leaves_as_strings = + |leaves: &[TenantId]| -> Vec { leaves.iter().map(|t| t.to_string()).collect() }; + // Single node tree.insert("hello", "tenant1"); let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); - assert_eq!(leaves, vec!["tenant1"]); + assert_eq!(leaves_as_strings(&leaves), vec!["tenant1"]); // Node with multiple tenants tree.insert("hello", "tenant2"); let leaves = Tree::leaf_of(&tree.root.children.get(&'h').unwrap()); - assert_eq!(leaves.len(), 2); - assert!(leaves.contains(&"tenant1".to_string())); - assert!(leaves.contains(&"tenant2".to_string())); + let leaves_str = leaves_as_strings(&leaves); + assert_eq!(leaves_str.len(), 2); + assert!(leaves_str.contains(&"tenant1".to_string())); + assert!(leaves_str.contains(&"tenant2".to_string())); // Non-leaf node tree.insert("hi", "tenant1");