Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions sgl-model-gateway/src/tokenizer/cache/l0.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! L0 Cache: Whole-string exact match cache
//!
//! This is the simplest and most effective cache layer.
//! Key: input string → Value: full encoding result
//! Key: input string → Value: full encoding result (Arc-wrapped for zero-copy cache hits)
//!
//! Expected hit rate: 60-90% for workloads with repeated system prompts

Expand All @@ -15,9 +15,10 @@ use dashmap::DashMap;
use super::super::traits::Encoding;

/// L0 cache implementation using DashMap for lock-free reads
/// Uses Arc<Encoding> internally to provide zero-copy cache hits
pub struct L0Cache {
/// The cache map: input string → encoding
map: Arc<DashMap<String, Encoding>>,
/// The cache map: input string → Arc-wrapped encoding for cheap cloning
map: Arc<DashMap<String, Arc<Encoding>>>,
/// Maximum number of entries before eviction
max_entries: usize,
/// Cache hit counter
Expand All @@ -37,12 +38,14 @@ impl L0Cache {
}
}

/// Get an encoding from the cache
pub fn get(&self, key: &str) -> Option<Encoding> {
/// Get an encoding from the cache (returns Arc for zero-copy access)
#[inline]
pub fn get(&self, key: &str) -> Option<Arc<Encoding>> {
match self.map.get(key) {
Some(entry) => {
self.hits.fetch_add(1, Ordering::Relaxed);
Some(entry.value().clone())
// Arc::clone is cheap (just increment reference count)
Some(Arc::clone(entry.value()))
}
None => {
self.misses.fetch_add(1, Ordering::Relaxed);
Expand All @@ -65,6 +68,17 @@ impl L0Cache {
}
}

self.map.insert(key, Arc::new(value));
}

/// Insert a pre-wrapped Arc encoding into the cache (avoids double-wrapping)
/// Insert a pre-wrapped `Arc<Encoding>` into the cache.
///
/// Use this when the caller already holds an `Arc<Encoding>` — it avoids
/// the double-wrapping that `insert` (which takes an owned `Encoding`)
/// would perform.
///
/// # Eviction
/// When the cache is at capacity, one arbitrary entry is removed first.
/// NOTE(review): `DashMap` iteration order is unspecified, so this is a
/// random-eviction policy, not LRU — confirm that is acceptable here.
pub fn insert_arc(&self, key: String, value: Arc<Encoding>) {
    if self.map.len() >= self.max_entries {
        // Pick a victim in a separate scope so the iterator's shard guard
        // is dropped before we call `remove` (avoids deadlocking DashMap).
        let victim = { self.map.iter().next().map(|entry| entry.key().clone()) };
        if let Some(k) = victim {
            self.map.remove(&k);
        }
    }
    self.map.insert(key, value);
}

Expand Down Expand Up @@ -139,7 +153,7 @@ mod tests {
// Insert
cache.insert("hello".to_string(), mock_encoding(vec![1, 2, 3]));

// Hit
// Hit - now returns Arc<Encoding>
let result = cache.get("hello");
assert!(result.is_some());
assert_eq!(result.unwrap().token_ids(), &[1, 2, 3]);
Expand Down Expand Up @@ -217,4 +231,17 @@ mod tests {
// Should have 10 entries
assert_eq!(cache.len(), 10);
}

#[test]
fn test_arc_reuse() {
    // A cache hit must be zero-copy: every `get` for the same key should
    // hand back an Arc pointing at the one shared allocation, so cloning
    // is just a refcount bump.
    let cache = L0Cache::new(10);
    cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));

    let first = cache.get("test").unwrap();
    let second = cache.get("test").unwrap();

    // Same underlying allocation ⇒ reference counting is working.
    assert!(Arc::ptr_eq(&first, &second));
}
}
21 changes: 12 additions & 9 deletions sgl-model-gateway/src/tokenizer/cache/l1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usi
}

/// A cached prefix entry
/// Uses Arc<[TokenIdType]> for zero-copy access to tokens
#[derive(Debug, Clone)]
struct CachedPrefix {
/// The pre-computed token IDs for this prefix
tokens: Vec<TokenIdType>,
/// The pre-computed token IDs for this prefix (Arc for zero-copy cloning)
tokens: Arc<[TokenIdType]>,
/// Last access timestamp (for LRU eviction)
last_accessed: Arc<AtomicU64>,
/// Size in bytes (for memory tracking during eviction)
Expand Down Expand Up @@ -127,6 +128,7 @@ impl L1Cache {
/// Returns (cached_tokens, byte_offset) if found
///
/// Uses pre-computed tokens cached during insertion.
/// Returns Vec<TokenIdType> as the caller needs to extend it with suffix tokens.
pub fn longest_prefix_match(
&self,
input: &str,
Expand Down Expand Up @@ -154,7 +156,8 @@ impl L1Cache {
entry.last_accessed.store(timestamp, Ordering::Relaxed);

self.hits.fetch_add(1, Ordering::Relaxed);
return Some((entry.tokens.clone(), boundary_pos));
// Convert Arc<[T]> to Vec<T> - caller will extend with suffix tokens
return Some((entry.tokens.to_vec(), boundary_pos));
}
}

Expand All @@ -181,7 +184,7 @@ impl L1Cache {
}

// Calculate how much memory we need and tokenize each prefix
let mut entries_to_insert = Vec::new();
let mut entries_to_insert = Vec::with_capacity(boundaries.len());
for &boundary_pos in &boundaries {
// Extract prefix up to this special token boundary
let prefix = &input[0..boundary_pos];
Expand All @@ -192,7 +195,8 @@ impl L1Cache {
// Re-tokenize the prefix for guaranteed correctness
// This is the only way to know the exact token boundaries
let prefix_encoding = tokenizer.encode(prefix)?;
let prefix_tokens = prefix_encoding.token_ids().to_vec();
// Convert to Arc<[TokenIdType]> for zero-copy sharing
let prefix_tokens: Arc<[TokenIdType]> = prefix_encoding.token_ids().into();

// Size = text bytes + token storage
let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
Expand All @@ -213,14 +217,13 @@ impl L1Cache {
}

// Insert all entries
let current_timestamp = self.access_counter.load(Ordering::Relaxed);
for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;

let cached = CachedPrefix {
tokens: prefix_tokens,
last_accessed: Arc::new(AtomicU64::new(
self.access_counter.load(Ordering::Relaxed),
)),
tokens: prefix_tokens, // Already Arc<[TokenIdType]>
last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
size_bytes,
};

Expand Down
31 changes: 17 additions & 14 deletions sgl-model-gateway/src/tokenizer/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,26 +163,25 @@ impl CachedTokenizer {

impl Encoder for CachedTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
// Collect special tokens once if L1 is enabled (avoid redundant allocation)
let special_tokens: Option<Vec<&str>> = self.l1.as_ref().map(|_| {
self.special_token_strings
.iter()
.map(|s| s.as_str())
.collect()
});

// L0 cache lookup (exact match)
// L0 cache lookup (exact match) - returns Arc<Encoding> for zero-copy
if let Some(l0) = &self.l0 {
if let Some(cached) = l0.get(input) {
return Ok(cached);
// Unwrap the Arc - since Encoding is Clone, we can return the inner value
// For callers who need the tokens, they can access via token_ids() which is &[u32]
return Ok((*cached).clone());
}
}

// L1 cache lookup (prefix match at special token boundaries)
if let Some(l1) = &self.l1 {
let tokens = special_tokens.as_ref().unwrap();
// Use pre-computed special tokens refs (avoids allocation per call)
let tokens: Vec<&str> = self
.special_token_strings
.iter()
.map(|s| s.as_str())
.collect();

if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, tokens) {
if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
// We have a prefix match - tokenize the suffix
let suffix = &input[prefix_len..];
if !suffix.is_empty() {
Expand Down Expand Up @@ -216,8 +215,12 @@ impl Encoder for CachedTokenizer {
// Cache in L1 at special token boundaries
// Re-tokenizes prefixes for correctness (optimized for high prefix reuse)
if let Some(l1) = &self.l1 {
let tokens = special_tokens.as_ref().unwrap();
let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), tokens);
let tokens: Vec<&str> = self
.special_token_strings
.iter()
.map(|s| s.as_str())
.collect();
let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens);
// Ignore errors in cache insertion - cache is best-effort
}

Expand Down
26 changes: 20 additions & 6 deletions sgl-model-gateway/src/tokenizer/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,26 @@ fn is_likely_json(buffer: &[u8]) -> bool {
/// Heuristic check for whether `buffer` looks like a SentencePiece model file.
///
/// SentencePiece models are protobuf-encoded, so this is a best-effort
/// sniff, not a validation:
/// 1. Buffers shorter than 12 bytes are rejected outright.
/// 2. Two common protobuf field-header prefixes are checked first (cheap).
/// 3. Otherwise the buffer is scanned once for well-known special-token
///    markers (`<unk`, `<s>`, `</s>`).
///
/// Using `starts_with` on each 4-byte window (rather than `==`) lets the
/// 3-byte marker `<s>` match inside a 4-byte window.
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    if buffer.len() < 12 {
        return false;
    }

    // Cheap header-pattern check before scanning the whole buffer.
    if buffer.starts_with(b"\x0a\x09") || buffer.starts_with(b"\x08\x00") {
        return true;
    }

    // Single pass over the buffer: test every 4-byte window against all
    // markers instead of issuing one `windows()` scan per pattern.
    const PATTERNS: [&[u8]; 3] = [b"<unk", b"<s>", b"</s>"];
    buffer
        .windows(4)
        .any(|w| PATTERNS.iter().any(|p| w.starts_with(p)))
}

/// Helper function to discover chat template files in a directory
Expand Down
8 changes: 8 additions & 0 deletions sgl-model-gateway/src/tokenizer/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ impl Sequence {
}

/// Check if the sequence is empty
#[inline]
pub fn is_empty(&self) -> bool {
self.token_ids.is_empty()
}

/// Get the length of the sequence
#[inline]
pub fn len(&self) -> usize {
self.token_ids.len()
}
Expand All @@ -111,6 +113,7 @@ impl Sequence {

/// Append a single token to the sequence and return newly decoded text
/// Based on HuggingFace TGI incremental decoding
#[inline]
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
// Store the old read offset before adding the new token
let old_read_offset = self.read_offset;
Expand Down Expand Up @@ -165,11 +168,13 @@ impl Sequence {
}

/// Get a reference to the tokenizer
#[inline]
pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
&self.tokenizer
}

/// Get the current token ids
#[inline]
pub fn token_ids(&self) -> &[TokenIdType] {
&self.token_ids
}
Expand All @@ -181,16 +186,19 @@ impl Sequence {
}

/// Get the prefix offset
#[inline]
pub fn prefix_offset(&self) -> usize {
self.prefix_offset
}

/// Get the read offset
#[inline]
pub fn read_offset(&self) -> usize {
self.read_offset
}

/// Get whether special tokens are skipped during decoding
#[inline]
pub fn skip_special_tokens(&self) -> bool {
self.skip_special_tokens
}
Expand Down
22 changes: 12 additions & 10 deletions sgl-model-gateway/src/tokenizer/stop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ impl StopSequenceDecoder {
if self.config.stop_tokens.contains(&token_id) {
self.stopped = true;

// Flush any jailed text before stopping
// Flush any jailed text before stopping - use mem::take to avoid clone
if !self.jail_buffer.is_empty() {
let output = self.jail_buffer.clone();
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
return Ok(SequenceDecoderOutput::StoppedWithText(std::mem::take(
&mut self.jail_buffer,
)));
}
return Ok(SequenceDecoderOutput::Stopped);
}
Expand Down Expand Up @@ -186,8 +186,10 @@ impl StopSequenceDecoder {

if let Some(split_pos) = best_split_pos {
// Hold the partial match, flush the rest
// Drain [0..split_pos] as output, keep [split_pos..] in jail_buffer
let to_output = self.jail_buffer.drain(..split_pos).collect::<String>();
// Use split_off for zero-copy: keeps [0..split_pos] in place, returns [split_pos..]
// Then swap so we output the prefix and keep the suffix
let suffix = self.jail_buffer.split_off(split_pos);
let to_output = std::mem::replace(&mut self.jail_buffer, suffix);

if to_output.is_empty() {
Ok(SequenceDecoderOutput::Held)
Expand All @@ -210,7 +212,8 @@ impl StopSequenceDecoder {
&mut self,
token_ids: &[TokenIdType],
) -> Result<Vec<SequenceDecoderOutput>> {
let mut outputs = Vec::new();
// Pre-allocate with exact capacity to avoid reallocations
let mut outputs = Vec::with_capacity(token_ids.len());
for &token_id in token_ids {
outputs.push(self.process_token(token_id)?);
}
Expand All @@ -220,9 +223,8 @@ impl StopSequenceDecoder {
/// Flush any held text
pub fn flush(&mut self) -> SequenceDecoderOutput {
if !self.jail_buffer.is_empty() {
let output = self.jail_buffer.clone();
self.jail_buffer.clear();
SequenceDecoderOutput::Text(output)
// Use mem::take to avoid clone - transfers ownership and leaves empty string
SequenceDecoderOutput::Text(std::mem::take(&mut self.jail_buffer))
} else {
SequenceDecoderOutput::Text(String::new())
}
Expand Down
4 changes: 3 additions & 1 deletion sgl-model-gateway/src/tokenizer/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl DecodeStream {

/// Step appends a token_id to the internal state and tries to produce a text chunk.
/// Returning `None` means the given id is not enough to produce a chunk.
#[inline]
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
self.all_token_ids.push(id);

Expand Down Expand Up @@ -71,7 +72,8 @@ impl DecodeStream {

/// Process multiple tokens at once
pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
let mut chunks = Vec::new();
// Pre-allocate with capacity - most tokens produce output
let mut chunks = Vec::with_capacity(token_ids.len());

for &token_id in token_ids {
if let Some(text) = self.step(token_id)? {
Expand Down
1 change: 1 addition & 0 deletions sgl-model-gateway/src/tokenizer/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub enum Encoding {

impl Encoding {
/// Returns a reference to token IDs - zero-copy operation
#[inline]
pub fn token_ids(&self) -> &[TokenIdType] {
match self {
Encoding::Hf(inner) => inner.get_ids(),
Expand Down
Loading