diff --git a/Cargo.lock b/Cargo.lock index dfbe2a590698..df62683a934b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1268,15 +1268,30 @@ dependencies = [ "which 4.4.2", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec 0.6.3", +] + [[package]] name = "bit-set" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ - "bit-vec", + "bit-vec 0.8.0", ] +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bit-vec" version = "0.8.0" @@ -2615,37 +2630,6 @@ dependencies = [ "syn 2.0.99", ] -[[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.99", -] - -[[package]] -name = "derive_builder_macro" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" -dependencies = [ - "derive_builder_core", - "syn 2.0.99", -] - [[package]] name = "digest" version = "0.10.7" @@ -2837,15 +2821,6 @@ version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" -[[package]] -name = "esaxx-rs" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" -dependencies = [ - "cc", -] - [[package]] name = "etcetera" version = "0.8.0" @@ -2905,13 +2880,24 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set 0.5.3", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + [[package]] name = "fancy-regex" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" dependencies = [ - "bit-set", + "bit-set 0.8.0", "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -3460,7 +3446,7 @@ dependencies = [ "temp-env", "tempfile", "thiserror 1.0.69", - "tokenizers", + "tiktoken-rs", "tokio", "tokio-cron-scheduler", "tokio-stream", @@ -4440,15 +4426,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.12.1" @@ -5293,22 +5270,6 @@ dependencies = [ "libc", ] -[[package]] -name = "macro_rules_attribute" -version = "0.2.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" -dependencies = [ - "macro_rules_attribute-proc_macro", - "paste", -] - -[[package]] -name = "macro_rules_attribute-proc_macro" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" - [[package]] name = "malloc_buf" version = "0.0.6" @@ -5576,27 +5537,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "monostate" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe1be9d0c75642e3e50fedc7ecadf1ef1cbce6eb66462153fc44245343fbee" -dependencies = [ - "monostate-impl", - "serde", -] - -[[package]] -name = "monostate-impl" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.99", -] - [[package]] name = "multimap" version = "0.10.1" @@ -6852,17 +6792,6 @@ dependencies = [ "rayon-core", ] -[[package]] -name = "rayon-cond" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" -dependencies = [ - "either", - "itertools 0.11.0", - "rayon", -] - [[package]] name = "rayon-core" version = "1.12.1" @@ -7817,18 +7746,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "spm_precompiled" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" -dependencies = [ - "base64 0.13.1", - "nom", - "serde", - "unicode-segmentation", -] - [[package]] name = "sqlparser" version = "0.49.0" @@ -8382,6 +8299,22 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken-rs" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44075987ee2486402f0808505dd65692163d243a337fc54363d49afac41087f6" +dependencies = [ + "anyhow", + "base64 0.21.7", + "bstr", + "fancy-regex 0.13.0", + "lazy_static", + "parking_lot", + "regex", + "rustc-hash 1.1.0", +] + [[package]] name = "time" version = "0.3.38" @@ -8459,38 +8392,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tokenizers" -version = "0.20.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" -dependencies = [ - "aho-corasick", - "derive_builder", - "esaxx-rs", - "getrandom 0.2.15", - "indicatif", - "itertools 0.12.1", - "lazy_static", - "log", - "macro_rules_attribute", - "monostate", - "onig", - "paste", - "rand 0.8.5", - "rayon", - "rayon-cond", - "regex", - "regex-syntax 0.8.5", - "serde", - "serde_json", - "spm_precompiled", - "thiserror 1.0.69", - "unicode-normalization-alignments", - "unicode-segmentation", - "unicode_categories", -] - [[package]] name = "tokio" version = "1.43.1" @@ -8866,7 +8767,7 @@ dependencies = [ "cfb", "chrono", "encoding_rs", - "fancy-regex", + "fancy-regex 0.14.0", "getrandom 0.2.15", "hmac", "html_parser", @@ -8899,15 +8800,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" 
-[[package]] -name = "unicode-normalization-alignments" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" -dependencies = [ - "smallvec", -] - [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -8926,12 +8818,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" -[[package]] -name = "unicode_categories" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" - [[package]] name = "uniffi" version = "0.29.2" diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index b76171574ef5..4a91d4288ff7 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -43,8 +43,8 @@ regex = "1.11.1" async-trait = "0.1" async-stream = "0.3" minijinja = "2.8.0" -tokenizers = "0.20.3" include_dir = "0.7.4" +tiktoken-rs = "0.6.0" chrono = { version = "0.4.38", features = ["serde"] } indoc = "2.0.5" nanoid = "0.4" diff --git a/crates/goose/benches/tokenization_benchmark.rs b/crates/goose/benches/tokenization_benchmark.rs index 708cc68bf516..85be9a0ea32f 100644 --- a/crates/goose/benches/tokenization_benchmark.rs +++ b/crates/goose/benches/tokenization_benchmark.rs @@ -3,18 +3,68 @@ use goose::token_counter::TokenCounter; fn benchmark_tokenization(c: &mut Criterion) { let lengths = [1_000, 5_000, 10_000, 50_000, 100_000, 124_000, 200_000]; - let tokenizer_names = ["Xenova--gpt-4o", "Xenova--claude-tokenizer"]; - - for tokenizer_name in tokenizer_names { - let counter = TokenCounter::new(tokenizer_name); - for &length in &lengths { - let text = "hello ".repeat(length); - c.bench_function(&format!("{}_{}_tokens", tokenizer_name, length), |b| { - b.iter(|| counter.count_tokens(black_box(&text))) - }); - } + + // Create a single token counter using the fixed o200k_base encoding + let counter = TokenCounter::new(); // Uses fixed o200k_base encoding + + for &length in &lengths { + let text = "hello ".repeat(length); + c.bench_function(&format!("o200k_base_{}_tokens", length), |b| { + b.iter(|| counter.count_tokens(black_box(&text))) + }); } } -criterion_group!(benches, benchmark_tokenization); +fn benchmark_async_tokenization(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let lengths = [1_000, 5_000, 10_000, 50_000, 100_000, 124_000, 200_000]; + + // Create an async token counter + let counter = rt.block_on(async { + goose::token_counter::create_async_token_counter() + .await + .unwrap() + }); + + for &length in &lengths { + let text = "hello ".repeat(length); + c.bench_function(&format!("async_o200k_base_{}_tokens", length), |b| { + b.iter(|| counter.count_tokens(black_box(&text))) + }); + } +} + +fn benchmark_cache_performance(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + + // Create an async token counter for cache testing + let counter = rt.block_on(async { + goose::token_counter::create_async_token_counter() + .await + .unwrap() + }); + + let test_texts = vec![ + "This is a test sentence for cache performance.", + "Another different sentence to test caching.", + "A third unique sentence for the benchmark.", + "This is a test sentence for cache performance.", // Repeat first one + "Another different sentence to test caching.", // Repeat second one + ]; + + c.bench_function("cache_hit_miss_pattern", 
|b| { + b.iter(|| { + for text in &test_texts { + counter.count_tokens(black_box(text)); + } + }) + }); +} + +criterion_group!( + benches, + benchmark_tokenization, + benchmark_async_tokenization, + benchmark_cache_performance +); criterion_main!(benches); diff --git a/crates/goose/build.rs b/crates/goose/build.rs deleted file mode 100644 index 1d9242661448..000000000000 --- a/crates/goose/build.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::error::Error; -use std::fs; -use std::path::Path; - -const BASE_DIR: &str = "../../tokenizer_files"; -const TOKENIZERS: &[&str] = &["Xenova/gpt-4o", "Xenova/claude-tokenizer"]; - -#[tokio::main] -async fn main() -> Result<(), Box> { - // Create base directory - fs::create_dir_all(BASE_DIR)?; - println!("cargo:rerun-if-changed=build.rs"); - println!("cargo:rerun-if-changed={BASE_DIR}"); - - for tokenizer_name in TOKENIZERS { - download_tokenizer(tokenizer_name).await?; - } - - Ok(()) -} - -async fn download_tokenizer(repo_id: &str) -> Result<(), Box> { - let dir_name = repo_id.replace('/', "--"); - let download_dir = format!("{BASE_DIR}/{dir_name}"); - let file_url = format!("https://huggingface.co/{repo_id}/resolve/main/tokenizer.json"); - let file_path = format!("{download_dir}/tokenizer.json"); - - // Create directory if it doesn't exist - fs::create_dir_all(&download_dir)?; - - // Check if file already exists - if Path::new(&file_path).exists() { - println!("Tokenizer for {repo_id} already exists, skipping..."); - return Ok(()); - } - - println!("Downloading tokenizer for {repo_id}..."); - - // Download the file - let response = reqwest::get(&file_url).await?; - if !response.status().is_success() { - return Err(format!( - "Failed to download tokenizer for {repo_id}, status: {}", - response.status() - ) - .into()); - } - - let content = response.bytes().await?; - fs::write(&file_path, content)?; - - println!("Downloaded {repo_id} to {file_path}"); - Ok(()) -} diff --git a/crates/goose/examples/async_token_counter_demo.rs b/crates/goose/examples/async_token_counter_demo.rs index 6b81f306454c..45aee116a505 100644 --- a/crates/goose/examples/async_token_counter_demo.rs +++ b/crates/goose/examples/async_token_counter_demo.rs @@ -39,7 +39,7 @@ async fn main() -> Result<(), Box> { // Test original TokenCounter let start = Instant::now(); - let sync_counter = TokenCounter::new("Xenova--gpt-4o"); + let sync_counter = TokenCounter::new(); let sync_init_time = start.elapsed(); let start = Instant::now(); @@ -56,7 +56,7 @@ async fn main() -> Result<(), Box> { // Test AsyncTokenCounter let start = Instant::now(); - let async_counter = create_async_token_counter("Xenova--gpt-4o").await?; + let async_counter = create_async_token_counter().await?; let async_init_time = start.elapsed(); let start = Instant::now(); @@ -89,20 +89,10 @@ async fn main() -> Result<(), Box> { assert_eq!(sync_total, async_total); assert_eq!(async_total, cached_total); - println!("\n✅ Key Improvements:"); - println!(" • No blocking runtime creation (eliminates deadlock risk)"); - println!(" • Global tokenizer caching with DashMap (lock-free concurrent access)"); - println!(" • Fast AHash for better cache performance"); - println!(" • Cache size management (prevents unbounded growth)"); println!( - " • Token result caching ({}x faster on repeated text)", + " Token result caching: {}x faster on cached text", async_count_time.as_nanos() / cached_time.as_nanos().max(1) ); - println!(" • Proper async patterns throughout"); - println!(" • Robust network failure handling with exponential backoff"); - 
println!(" • Download validation and corruption detection"); - println!(" • Progress reporting for large tokenizer downloads"); - println!(" • Smart retry logic (3 attempts, server errors only)"); Ok(()) } diff --git a/crates/goose/src/agents/context.rs b/crates/goose/src/agents/context.rs index d21c85d344fc..7ef1c267e3b7 100644 --- a/crates/goose/src/agents/context.rs +++ b/crates/goose/src/agents/context.rs @@ -16,10 +16,9 @@ impl Agent { messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded ) -> Result<(Vec, Vec), anyhow::Error> { let provider = self.provider().await?; - let token_counter = - create_async_token_counter(provider.get_model_config().tokenizer_name()) - .await - .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?; + let token_counter = create_async_token_counter() + .await + .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?; let target_context_limit = estimate_target_context_limit(provider); let token_counts = get_messages_token_counts_async(&token_counter, messages); @@ -54,10 +53,9 @@ impl Agent { messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded ) -> Result<(Vec, Vec), anyhow::Error> { let provider = self.provider().await?; - let token_counter = - create_async_token_counter(provider.get_model_config().tokenizer_name()) - .await - .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?; + let token_counter = create_async_token_counter() + .await + .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?; let target_context_limit = estimate_target_context_limit(provider.clone()); let (mut new_messages, mut new_token_counts) = diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs index 772a9f683b18..75fe05b534c2 100644 --- a/crates/goose/src/context_mgmt/summarize.rs +++ b/crates/goose/src/context_mgmt/summarize.rs @@ -216,7 +216,7 @@ pub async fn summarize_messages_async( mod tests { use super::*; use crate::message::{Message, MessageContent}; - use crate::model::{ModelConfig, GPT_4O_TOKENIZER}; + use crate::model::ModelConfig; use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::errors::ProviderError; use chrono::Utc; @@ -306,7 +306,7 @@ mod tests { #[tokio::test] async fn test_summarize_messages_single_chunk() { let provider = create_mock_provider(); - let token_counter = TokenCounter::new(GPT_4O_TOKENIZER); + let token_counter = TokenCounter::new(); let context_limit = 100; // Set a high enough limit to avoid chunking. 
let messages = create_test_messages(); @@ -342,7 +342,7 @@ mod tests { #[tokio::test] async fn test_summarize_messages_multiple_chunks() { let provider = create_mock_provider(); - let token_counter = TokenCounter::new(GPT_4O_TOKENIZER); + let token_counter = TokenCounter::new(); let context_limit = 30; let messages = create_test_messages(); @@ -378,7 +378,7 @@ mod tests { #[tokio::test] async fn test_summarize_messages_empty_input() { let provider = create_mock_provider(); - let token_counter = TokenCounter::new(GPT_4O_TOKENIZER); + let token_counter = TokenCounter::new(); let context_limit = 100; let messages: Vec = Vec::new(); diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index c8e28e47c947..60df7dc6ce61 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -4,10 +4,6 @@ use std::collections::HashMap; const DEFAULT_CONTEXT_LIMIT: usize = 128_000; -// Tokenizer names, used to infer from model name -pub const GPT_4O_TOKENIZER: &str = "Xenova--gpt-4o"; -pub const CLAUDE_TOKENIZER: &str = "Xenova--claude-tokenizer"; - // Define the model limits as a static HashMap for reuse static MODEL_SPECIFIC_LIMITS: Lazy> = Lazy::new(|| { let mut map = HashMap::new(); @@ -41,10 +37,6 @@ static MODEL_SPECIFIC_LIMITS: Lazy> = Lazy::new(|| pub struct ModelConfig { /// The name of the model to use pub model_name: String, - // Optional tokenizer name (corresponds to the sanitized HuggingFace tokenizer name) - // "Xenova/gpt-4o" -> "Xenova/gpt-4o" - // If not provided, best attempt will be made to infer from model name or default - pub tokenizer_name: String, /// Optional explicit context limit that overrides any defaults pub context_limit: Option, /// Optional temperature setting (0.0 - 1.0) @@ -73,7 +65,6 @@ impl ModelConfig { /// 3. 
Global default (128_000) (in get_context_limit) pub fn new(model_name: String) -> Self { let context_limit = Self::get_model_specific_limit(&model_name); - let tokenizer_name = Self::infer_tokenizer_name(&model_name); let toolshim = std::env::var("GOOSE_TOOLSHIM") .map(|val| val == "1" || val.to_lowercase() == "true") @@ -87,7 +78,6 @@ impl ModelConfig { Self { model_name, - tokenizer_name: tokenizer_name.to_string(), context_limit, temperature, max_tokens: None, @@ -96,15 +86,6 @@ impl ModelConfig { } } - fn infer_tokenizer_name(model_name: &str) -> &'static str { - if model_name.contains("claude") { - CLAUDE_TOKENIZER - } else { - // Default tokenizer - GPT_4O_TOKENIZER - } - } - /// Get model-specific context limit based on model name fn get_model_specific_limit(model_name: &str) -> Option { for (pattern, &limit) in MODEL_SPECIFIC_LIMITS.iter() { @@ -161,11 +142,6 @@ impl ModelConfig { self } - /// Get the tokenizer name - pub fn tokenizer_name(&self) -> &str { - &self.tokenizer_name - } - /// Get the context_limit for the current model /// If none are defined, use the DEFAULT_CONTEXT_LIMIT pub fn context_limit(&self) -> usize { diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index aa364ad3e992..eae586495b09 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -150,6 +150,7 @@ mod tests { use mcp_core::{content::TextContent, Role}; use std::env; + #[warn(dead_code)] #[derive(Clone)] struct MockTestProvider { name: String, diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index eb0db4136f1f..c74c4cbbe033 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -975,7 +975,6 @@ mod tests { // Test default medium reasoning effort for O3 model let model_config = ModelConfig { model_name: "gpt-4o".to_string(), - tokenizer_name: "gpt-4o".to_string(), context_limit: Some(4096), temperature: None, max_tokens: Some(1024), @@ -1007,7 +1006,6 @@ mod tests { // Test default medium reasoning effort for O1 model let model_config = ModelConfig { model_name: "o1".to_string(), - tokenizer_name: "o1".to_string(), context_limit: Some(4096), temperature: None, max_tokens: Some(1024), @@ -1040,7 +1038,6 @@ mod tests { // Test custom reasoning effort for O3 model let model_config = ModelConfig { model_name: "o3-mini-high".to_string(), - tokenizer_name: "o3-mini".to_string(), context_limit: Some(4096), temperature: None, max_tokens: Some(1024), diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index fc7de71c82a7..402e09cfee99 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -853,7 +853,6 @@ mod tests { // Test default medium reasoning effort for O3 model let model_config = ModelConfig { model_name: "gpt-4o".to_string(), - tokenizer_name: "gpt-4o".to_string(), context_limit: Some(4096), temperature: None, max_tokens: Some(1024), @@ -885,7 +884,6 @@ mod tests { // Test default medium reasoning effort for O1 model let model_config = ModelConfig { model_name: "o1".to_string(), - tokenizer_name: "o1".to_string(), context_limit: Some(4096), temperature: None, max_tokens: Some(1024), @@ -918,7 +916,6 @@ mod tests { // Test custom reasoning effort for O3 model let model_config = ModelConfig { model_name: "o3-mini-high".to_string(), - tokenizer_name: "o3-mini".to_string(), context_limit: 
Some(4096), temperature: None, max_tokens: Some(1024), diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 7be9709ae767..c83ab8132430 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -1,275 +1,40 @@ use ahash::AHasher; use dashmap::DashMap; -use futures_util::stream::StreamExt; -use include_dir::{include_dir, Dir}; use mcp_core::Tool; -use std::error::Error; -use std::fs; use std::hash::{Hash, Hasher}; -use std::path::Path; use std::sync::Arc; -use tokenizers::tokenizer::Tokenizer; +use tiktoken_rs::CoreBPE; use tokio::sync::OnceCell; use crate::message::Message; -// The embedded directory with all possible tokenizer files. -// If one of them doesn’t exist, we’ll download it at startup. -static TOKENIZER_FILES: Dir = include_dir!("$CARGO_MANIFEST_DIR/../../tokenizer_files"); - -// Global tokenizer cache to avoid repeated downloads and loading -static TOKENIZER_CACHE: OnceCell<Arc<DashMap<String, Arc<Tokenizer>>>> = OnceCell::const_new(); +// Global tokenizer instance to avoid repeated initialization +static TOKENIZER: OnceCell<Arc<CoreBPE>> = OnceCell::const_new(); // Cache size limits to prevent unbounded growth const MAX_TOKEN_CACHE_SIZE: usize = 10_000; -const MAX_TOKENIZER_CACHE_SIZE: usize = 50; /// Async token counter with caching capabilities pub struct AsyncTokenCounter { - tokenizer: Arc<Tokenizer>, + tokenizer: Arc<CoreBPE>, token_cache: Arc<DashMap<u64, usize>>, // content hash -> token count } /// Legacy synchronous token counter for backward compatibility pub struct TokenCounter { - tokenizer: Tokenizer, + tokenizer: Arc<CoreBPE>, } impl AsyncTokenCounter { /// Creates a new async token counter with caching - pub async fn new(tokenizer_name: &str) -> Result<Self, Box<dyn Error>> { - // Initialize global cache if not already done - let cache = TOKENIZER_CACHE - .get_or_init(|| async { Arc::new(DashMap::new()) }) - .await; - - // Check cache first - DashMap allows concurrent reads - if let Some(tokenizer) = cache.get(tokenizer_name) { - return Ok(Self { - tokenizer: tokenizer.clone(), - token_cache: Arc::new(DashMap::new()), - }); - } - - // Try embedded first - let tokenizer = match Self::load_from_embedded(tokenizer_name) { - Ok(tokenizer) => Arc::new(tokenizer), - Err(_) => { - // Download async if not found - Arc::new(Self::download_and_load_async(tokenizer_name).await?)
- } - }; - - // Cache the tokenizer with size management - if cache.len() >= MAX_TOKENIZER_CACHE_SIZE { - // Simple eviction: remove oldest entry - if let Some(entry) = cache.iter().next() { - let old_key = entry.key().clone(); - cache.remove(&old_key); - } - } - cache.insert(tokenizer_name.to_string(), tokenizer.clone()); - + pub async fn new() -> Result { + let tokenizer = get_tokenizer().await?; Ok(Self { tokenizer, token_cache: Arc::new(DashMap::new()), }) } - /// Load tokenizer bytes from the embedded directory - fn load_from_embedded(tokenizer_name: &str) -> Result> { - let tokenizer_file_path = format!("{}/tokenizer.json", tokenizer_name); - let file = TOKENIZER_FILES - .get_file(&tokenizer_file_path) - .ok_or_else(|| { - format!( - "Tokenizer file not found in embedded: {}", - tokenizer_file_path - ) - })?; - let contents = file.contents(); - let tokenizer = Tokenizer::from_bytes(contents) - .map_err(|e| format!("Failed to parse tokenizer bytes: {}", e))?; - Ok(tokenizer) - } - - /// Async download that doesn't block the runtime - async fn download_and_load_async( - tokenizer_name: &str, - ) -> Result> { - let local_dir = std::env::temp_dir().join(tokenizer_name); - let local_json_path = local_dir.join("tokenizer.json"); - - // Check if file exists - if !tokio::fs::try_exists(&local_json_path) - .await - .unwrap_or(false) - { - eprintln!("Downloading tokenizer: {}", tokenizer_name); - let repo_id = tokenizer_name.replace("--", "/"); - Self::download_tokenizer_async(&repo_id, &local_dir).await?; - } - - // Load from disk asynchronously - let file_content = tokio::fs::read(&local_json_path).await?; - let tokenizer = Tokenizer::from_bytes(&file_content) - .map_err(|e| format!("Failed to parse tokenizer: {}", e))?; - - Ok(tokenizer) - } - - /// Robust async download with retry logic and network failure handling - async fn download_tokenizer_async( - repo_id: &str, - download_dir: &std::path::Path, - ) -> Result<(), Box> { - tokio::fs::create_dir_all(download_dir).await?; - - let file_url = format!( - "https://huggingface.co/{}/resolve/main/tokenizer.json", - repo_id - ); - let file_path = download_dir.join("tokenizer.json"); - - // Check if partial/corrupted file exists and remove it - if file_path.exists() { - if let Ok(existing_bytes) = tokio::fs::read(&file_path).await { - if Self::is_valid_tokenizer_json(&existing_bytes) { - return Ok(()); // File is complete and valid - } - } - // Remove corrupted/incomplete file - let _ = tokio::fs::remove_file(&file_path).await; - } - - // Create enhanced HTTP client with timeouts - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(60)) - .connect_timeout(std::time::Duration::from_secs(15)) - .user_agent("goose-tokenizer/1.0") - .build()?; - - // Download with retry logic - let response = Self::download_with_retry(&client, &file_url, 3).await?; - - // Stream download with progress reporting for large files - let total_size = response.content_length(); - let mut stream = response.bytes_stream(); - let mut file = tokio::fs::File::create(&file_path).await?; - let mut downloaded = 0; - - use tokio::io::AsyncWriteExt; - - while let Some(chunk_result) = stream.next().await { - let chunk = chunk_result?; - file.write_all(&chunk).await?; - downloaded += chunk.len(); - - // Progress reporting for large downloads - if let Some(total) = total_size { - if total > 1024 * 1024 && downloaded % (256 * 1024) == 0 { - // Report every 256KB for files >1MB - eprintln!( - "Downloaded {}/{} bytes ({:.1}%)", - downloaded, - total, - 
(downloaded as f64 / total as f64) * 100.0 - ); - } - } - } - - file.flush().await?; - - // Validate downloaded file - let final_bytes = tokio::fs::read(&file_path).await?; - if !Self::is_valid_tokenizer_json(&final_bytes) { - tokio::fs::remove_file(&file_path).await?; - return Err("Downloaded tokenizer file is invalid or corrupted".into()); - } - - eprintln!( - "Successfully downloaded tokenizer: {} ({} bytes)", - repo_id, downloaded - ); - Ok(()) - } - - /// Download with exponential backoff retry logic - async fn download_with_retry( - client: &reqwest::Client, - url: &str, - max_retries: u32, - ) -> Result> { - let mut delay = std::time::Duration::from_millis(200); - - for attempt in 0..=max_retries { - match client.get(url).send().await { - Ok(response) if response.status().is_success() => { - return Ok(response); - } - Ok(response) if response.status().is_server_error() => { - // Retry on 5xx errors (server issues) - if attempt < max_retries { - eprintln!( - "Server error {} on attempt {}/{}, retrying in {:?}", - response.status(), - attempt + 1, - max_retries + 1, - delay - ); - tokio::time::sleep(delay).await; - delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30)); // Cap at 30s - continue; - } - return Err(format!( - "Server error after {} retries: {}", - max_retries, - response.status() - ) - .into()); - } - Ok(response) => { - // Don't retry on 4xx errors (client errors like 404, 403) - return Err(format!("Client error: {} - {}", response.status(), url).into()); - } - Err(e) if attempt < max_retries => { - // Retry on network errors (timeout, connection refused, DNS, etc.) - eprintln!( - "Network error on attempt {}/{}: {}, retrying in {:?}", - attempt + 1, - max_retries + 1, - e, - delay - ); - tokio::time::sleep(delay).await; - delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30)); // Cap at 30s - continue; - } - Err(e) => { - return Err( - format!("Network error after {} retries: {}", max_retries, e).into(), - ); - } - } - } - unreachable!() - } - - /// Validate that the downloaded file is a valid tokenizer JSON - fn is_valid_tokenizer_json(bytes: &[u8]) -> bool { - // Basic validation: check if it's valid JSON and has tokenizer structure - if let Ok(json_str) = std::str::from_utf8(bytes) { - if let Ok(json_value) = serde_json::from_str::(json_str) { - // Check for basic tokenizer structure - return json_value.get("version").is_some() - || json_value.get("vocab").is_some() - || json_value.get("model").is_some(); - } - } - false - } - /// Count tokens with optimized caching pub fn count_tokens(&self, text: &str) -> usize { // Use faster AHash for better performance @@ -283,8 +48,8 @@ impl AsyncTokenCounter { } // Compute and cache result with size management - let encoding = self.tokenizer.encode(text, false).unwrap_or_default(); - let count = encoding.len(); + let tokens = self.tokenizer.encode_with_special_tokens(text); + let count = tokens.len(); // Manage cache size to prevent unbounded growth if self.token_cache.len() >= MAX_TOKEN_CACHE_SIZE { @@ -316,7 +81,6 @@ impl AsyncTokenCounter { let name = &tool.name; let description = &tool.description.trim_end_matches('.'); - // Optimize: count components separately to avoid string allocation // Note: the separator (:) is likely tokenized with adjacent tokens, so we use original approach for accuracy let line = format!("{}:{}", name, description); func_token_count += self.count_tokens(&line); @@ -427,99 +191,24 @@ impl AsyncTokenCounter { } } -impl TokenCounter { - /// Creates a new `TokenCounter` 
using the given HuggingFace tokenizer name. - /// - /// * `tokenizer_name` might look like "Xenova--gpt-4o" - /// or "Qwen--Qwen2.5-Coder-32B-Instruct", etc. - pub fn new(tokenizer_name: &str) -> Self { - match Self::load_from_embedded(tokenizer_name) { - Ok(tokenizer) => Self { tokenizer }, - Err(e) => { - println!( - "Tokenizer '{}' not found in embedded dir: {}", - tokenizer_name, e - ); - println!("Attempting to download tokenizer and load..."); - // Fallback to download tokenizer and load from disk - match Self::download_and_load(tokenizer_name) { - Ok(counter) => counter, - Err(e) => panic!("Failed to initialize tokenizer: {}", e), - } - } - } - } - - /// Load tokenizer bytes from the embedded directory (via `include_dir!`). - fn load_from_embedded(tokenizer_name: &str) -> Result> { - let tokenizer_file_path = format!("{}/tokenizer.json", tokenizer_name); - let file = TOKENIZER_FILES - .get_file(&tokenizer_file_path) - .ok_or_else(|| { - format!( - "Tokenizer file not found in embedded: {}", - tokenizer_file_path - ) - })?; - let contents = file.contents(); - let tokenizer = Tokenizer::from_bytes(contents) - .map_err(|e| format!("Failed to parse tokenizer bytes: {}", e))?; - Ok(tokenizer) - } - - /// Fallback: If not found in embedded, we look in `base_dir` on disk. - /// If not on disk, we download from Hugging Face, then load from disk. - fn download_and_load(tokenizer_name: &str) -> Result> { - let local_dir = std::env::temp_dir().join(tokenizer_name); - let local_json_path = local_dir.join("tokenizer.json"); - - // If the file doesn't already exist, we download from HF - if !Path::new(&local_json_path).exists() { - eprintln!("Tokenizer file not on disk, downloading…"); - let repo_id = tokenizer_name.replace("--", "/"); - // e.g. "Xenova--llama3-tokenizer" -> "Xenova/llama3-tokenizer" - Self::download_tokenizer(&repo_id, &local_dir)?; - } - - // Load from disk - let file_content = fs::read(&local_json_path)?; - let tokenizer = Tokenizer::from_bytes(&file_content) - .map_err(|e| format!("Failed to parse tokenizer after download: {}", e))?; - - Ok(Self { tokenizer }) +impl Default for TokenCounter { + fn default() -> Self { + Self::new() } +} - /// DEPRECATED: Use AsyncTokenCounter for new code - /// Download from Hugging Face into the local directory if not already present. - /// This method still blocks but is kept for backward compatibility. - fn download_tokenizer(repo_id: &str, download_dir: &Path) -> Result<(), Box> { - std::fs::create_dir_all(download_dir)?; - - let file_url = format!( - "https://huggingface.co/{}/resolve/main/tokenizer.json", - repo_id - ); - let file_path = download_dir.join("tokenizer.json"); - - // Use blocking reqwest client to avoid nested runtime - let client = reqwest::blocking::Client::new(); - let response = client.get(&file_url).send()?; - - if !response.status().is_success() { - let error_msg = format!("Failed to download tokenizer: status {}", response.status()); - return Err(Box::::from(error_msg)); - } - - let bytes = response.bytes()?; - std::fs::write(&file_path, bytes)?; - - Ok(()) +impl TokenCounter { + /// Creates a new `TokenCounter` using the fixed o200k_base encoding. + pub fn new() -> Self { + // Use blocking version of get_tokenizer + let tokenizer = get_tokenizer_blocking().expect("Failed to initialize tokenizer"); + Self { tokenizer } } /// Count tokens for a piece of text using our single tokenizer. 
pub fn count_tokens(&self, text: &str) -> usize { - let encoding = self.tokenizer.encode(text, false).unwrap(); - encoding.len() + let tokens = self.tokenizer.encode_with_special_tokens(text); + tokens.len() } pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize { @@ -596,7 +285,6 @@ impl TokenCounter { if let Some(content_text) = content.as_text() { num_tokens += self.count_tokens(content_text); } else if let Some(tool_request) = content.as_tool_request() { - // TODO: count tokens for tool request let tool_call = tool_request.tool_call.as_ref().unwrap(); let text = format!( "{}:{}:{}", @@ -641,49 +329,79 @@ impl TokenCounter { } } +/// Get the global tokenizer instance (async version) +/// Fixed encoding for all tokenization - using o200k_base for GPT-4o and o1 models +async fn get_tokenizer() -> Result<Arc<CoreBPE>, String> { + let tokenizer = TOKENIZER + .get_or_init(|| async { + match tiktoken_rs::o200k_base() { + Ok(bpe) => Arc::new(bpe), + Err(e) => panic!("Failed to initialize o200k_base tokenizer: {}", e), + } + }) + .await; + Ok(tokenizer.clone()) +} + +/// Get the global tokenizer instance (blocking version for backward compatibility) +fn get_tokenizer_blocking() -> Result<Arc<CoreBPE>, String> { + // For the blocking version, we need to handle the case where the tokenizer hasn't been initialized yet + if let Some(tokenizer) = TOKENIZER.get() { + return Ok(tokenizer.clone()); + } + + // Initialize the tokenizer synchronously + match tiktoken_rs::o200k_base() { + Ok(bpe) => { + let tokenizer = Arc::new(bpe); + // Try to set it in the OnceCell, but it's okay if another thread beat us to it + let _ = TOKENIZER.set(tokenizer.clone()); + Ok(tokenizer) + } + Err(e) => Err(format!("Failed to initialize o200k_base tokenizer: {}", e)), + } +} + /// Factory function for creating async token counters with proper error handling -pub async fn create_async_token_counter(tokenizer_name: &str) -> Result<AsyncTokenCounter, String> { - AsyncTokenCounter::new(tokenizer_name) - .await - .map_err(|e| format!("Failed to initialize tokenizer '{}': {}", tokenizer_name, e)) +pub async fn create_async_token_counter() -> Result<AsyncTokenCounter, String> { + AsyncTokenCounter::new().await } #[cfg(test)] mod tests { use super::*; - use crate::message::{Message, MessageContent}; // or however your `Message` is imported - use crate::model::{CLAUDE_TOKENIZER, GPT_4O_TOKENIZER}; + use crate::message::{Message, MessageContent}; use mcp_core::role::Role; use mcp_core::tool::Tool; use serde_json::json; #[test] - fn test_claude_tokenizer() { - let counter = TokenCounter::new(CLAUDE_TOKENIZER); + fn test_token_counter_basic() { + let counter = TokenCounter::new(); let text = "Hello, how are you?"; let count = counter.count_tokens(text); println!("Token count for '{}': {:?}", text, count); - // The old test expected 6 tokens - assert_eq!(count, 6, "Claude tokenizer token count mismatch"); + // With o200k_base encoding, this should give us a reasonable count + assert!(count > 0, "Token count should be greater than 0"); } #[test] - fn test_gpt_4o_tokenizer() { - let counter = TokenCounter::new(GPT_4O_TOKENIZER); + fn test_token_counter_simple_text() { + let counter = TokenCounter::new(); let text = "Hey there!"; let count = counter.count_tokens(text); println!("Token count for '{}': {:?}", text, count); - // The old test expected 3 tokens - assert_eq!(count, 3, "GPT-4o tokenizer token count mismatch"); + // With o200k_base encoding, this should give us a reasonable count + assert!(count > 0, "Token count should be greater than 0"); } #[test] fn test_count_chat_tokens() { - let counter =
TokenCounter::new(GPT_4O_TOKENIZER); + let counter = TokenCounter::new(); let system_prompt = "You are a helpful assistant that can answer questions about the weather."; @@ -736,65 +454,31 @@ mod tests { let token_count_with_tools = counter.count_chat_tokens(system_prompt, &messages, &tools); println!("Total tokens with tools: {}", token_count_with_tools); - // The old test used 56 / 124 for GPT-4o. Adjust if your actual tokenizer changes - assert_eq!(token_count_without_tools, 56); - assert_eq!(token_count_with_tools, 124); - } - - #[test] - #[should_panic] - fn test_panic_if_provided_tokenizer_doesnt_exist() { - // This should panic because the tokenizer doesn't exist - // in the embedded directory and the download fails - - TokenCounter::new("nonexistent-tokenizer"); - } - - // Optional test to confirm that fallback download works if not found in embedded: - // Ignored cause this actually downloads a tokenizer from Hugging Face - #[test] - #[ignore] - fn test_download_tokenizer_successfully_if_not_embedded() { - let non_embedded_key = "openai-community/gpt2"; - let counter = TokenCounter::new(non_embedded_key); - - // If it downloads successfully, we can do a quick count to ensure it's valid - let text = "print('hello world')"; - let count = counter.count_tokens(text); - println!( - "Downloaded tokenizer, token count for '{}': {}", - text, count + // Basic sanity checks - with o200k_base the exact counts may differ from the old tokenizer + assert!( + token_count_without_tools > 0, + "Should have some tokens without tools" + ); + assert!( + token_count_with_tools > token_count_without_tools, + "Should have more tokens with tools" ); - - // https://tiktokenizer.vercel.app/?model=gpt2 - assert!(count == 5, "Expected 5 tokens from downloaded tokenizer"); } #[tokio::test] - async fn test_async_claude_tokenizer() { - let counter = create_async_token_counter(CLAUDE_TOKENIZER).await.unwrap(); + async fn test_async_token_counter() { + let counter = create_async_token_counter().await.unwrap(); let text = "Hello, how are you?"; let count = counter.count_tokens(text); println!("Async token count for '{}': {:?}", text, count); - assert_eq!(count, 6, "Async Claude tokenizer token count mismatch"); - } - - #[tokio::test] - async fn test_async_gpt_4o_tokenizer() { - let counter = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); - - let text = "Hey there!"; - let count = counter.count_tokens(text); - println!("Async token count for '{}': {:?}", text, count); - - assert_eq!(count, 3, "Async GPT-4o tokenizer token count mismatch"); + assert!(count > 0, "Async token count should be greater than 0"); } #[tokio::test] async fn test_async_token_caching() { - let counter = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); + let counter = create_async_token_counter().await.unwrap(); let text = "This is a test for caching functionality"; @@ -815,7 +499,7 @@ mod tests { #[tokio::test] async fn test_async_count_chat_tokens() { - let counter = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); + let counter = create_async_token_counter().await.unwrap(); let system_prompt = "You are a helpful assistant that can answer questions about the weather."; @@ -871,28 +555,20 @@ mod tests { let token_count_with_tools = counter.count_chat_tokens(system_prompt, &messages, &tools); println!("Async total tokens with tools: {}", token_count_with_tools); - // Should match the synchronous version - assert_eq!(token_count_without_tools, 56); - assert_eq!(token_count_with_tools, 124); - } - - 
#[tokio::test] - async fn test_async_tokenizer_caching() { - // Create two counters with the same tokenizer name - let counter1 = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); - let counter2 = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); - - // Both should work and give same results (tokenizer is cached globally) - let text = "Test tokenizer caching"; - let count1 = counter1.count_tokens(text); - let count2 = counter2.count_tokens(text); - - assert_eq!(count1, count2); + // Basic sanity checks + assert!( + token_count_without_tools > 0, + "Should have some tokens without tools" + ); + assert!( + token_count_with_tools > token_count_without_tools, + "Should have more tokens with tools" + ); } #[tokio::test] async fn test_async_cache_management() { - let counter = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); + let counter = create_async_token_counter().await.unwrap(); // Add some items to cache counter.count_tokens("First text"); @@ -915,9 +591,7 @@ mod tests { async fn test_concurrent_token_counter_creation() { // Test concurrent creation of token counters to verify no race conditions let handles: Vec<_> = (0..10) - .map(|_| { - tokio::spawn(async { create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap() }) - }) + .map(|_| tokio::spawn(async { create_async_token_counter().await.unwrap() })) .collect(); let counters: Vec<_> = futures::future::join_all(handles) @@ -937,7 +611,7 @@ mod tests { #[tokio::test] async fn test_cache_eviction_behavior() { - let counter = create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap(); + let counter = create_async_token_counter().await.unwrap(); // Fill cache beyond normal size to test eviction let mut cached_texts = Vec::new(); @@ -959,17 +633,9 @@ mod tests { assert_eq!(counter.cache_size(), start_size); } - #[tokio::test] - async fn test_async_error_handling() { - // Test with invalid tokenizer name - let result = create_async_token_counter("invalid/nonexistent-tokenizer").await; - assert!(result.is_err(), "Should fail with invalid tokenizer name"); - } - #[tokio::test] async fn test_concurrent_cache_operations() { - let counter = - std::sync::Arc::new(create_async_token_counter(GPT_4O_TOKENIZER).await.unwrap()); + let counter = std::sync::Arc::new(create_async_token_counter().await.unwrap()); // Test concurrent token counting operations let handles: Vec<_> = (0..20) @@ -999,90 +665,25 @@ mod tests { } #[test] - fn test_tokenizer_json_validation() { - // Test valid tokenizer JSON - let valid_json = r#"{"version": "1.0", "model": {"type": "BPE"}}"#; - assert!(AsyncTokenCounter::is_valid_tokenizer_json( - valid_json.as_bytes() - )); - - let valid_json2 = r#"{"vocab": {"hello": 1, "world": 2}}"#; - assert!(AsyncTokenCounter::is_valid_tokenizer_json( - valid_json2.as_bytes() - )); - - // Test invalid JSON - let invalid_json = r#"{"incomplete": true"#; - assert!(!AsyncTokenCounter::is_valid_tokenizer_json( - invalid_json.as_bytes() - )); - - // Test valid JSON but not tokenizer structure - let wrong_structure = r#"{"random": "data", "not": "tokenizer"}"#; - assert!(!AsyncTokenCounter::is_valid_tokenizer_json( - wrong_structure.as_bytes() - )); - - // Test binary data - let binary_data = [0xFF, 0xFE, 0x00, 0x01]; - assert!(!AsyncTokenCounter::is_valid_tokenizer_json(&binary_data)); - - // Test empty data - assert!(!AsyncTokenCounter::is_valid_tokenizer_json(&[])); - } + fn test_tokenizer_consistency() { + // Test that both sync and async versions give the same results + let sync_counter = 
TokenCounter::new(); + let text = "This is a test for tokenizer consistency"; + let sync_count = sync_counter.count_tokens(text); - #[tokio::test] - async fn test_download_with_retry_logic() { - // This test would require mocking HTTP responses - // For now, we test the retry logic structure by verifying the function exists - // In a full test suite, you'd use wiremock or similar to simulate failures - - // Test that the function exists and has the right signature - let client = reqwest::Client::new(); + // Test that the tokenizer is working correctly + assert!(sync_count > 0, "Sync tokenizer should produce tokens"); - // Test with a known bad URL to verify error handling - let result = - AsyncTokenCounter::download_with_retry(&client, "https://httpbin.org/status/404", 1) - .await; + // Test with different text lengths + let short_text = "Hi"; + let long_text = "This is a much longer text that should produce significantly more tokens than the short text"; - assert!(result.is_err(), "Should fail with 404 error"); + let short_count = sync_counter.count_tokens(short_text); + let long_count = sync_counter.count_tokens(long_text); - let error_msg = result.unwrap_err().to_string(); assert!( - error_msg.contains("Client error: 404"), - "Should contain client error message" + short_count < long_count, + "Longer text should have more tokens" ); } - - #[tokio::test] - async fn test_network_resilience_with_timeout() { - // Test timeout handling with a slow endpoint - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_millis(100)) // Very short timeout - .build() - .unwrap(); - - // Use httpbin delay endpoint that takes longer than our timeout - let result = AsyncTokenCounter::download_with_retry( - &client, - "https://httpbin.org/delay/1", // 1 second delay, but 100ms timeout - 1, - ) - .await; - - assert!(result.is_err(), "Should timeout and fail"); - } - - #[tokio::test] - async fn test_successful_download_retry() { - // Test successful download after simulated retry - let client = reqwest::Client::new(); - - // Use a reliable endpoint that should succeed - let result = - AsyncTokenCounter::download_with_retry(&client, "https://httpbin.org/status/200", 2) - .await; - - assert!(result.is_ok(), "Should succeed with 200 status"); - } }
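
Editor's note: the new counting path in this patch reduces to one shared o200k_base encoder from tiktoken-rs plus a cache of previously counted strings. The sketch below is illustrative only and is not part of the patch: it swaps the crate's tokio OnceCell, DashMap, and AHash for std::sync::OnceLock and a mutexed HashMap, and the helper names (encoder, count_tokens, count_cached) are invented for the example. The tiktoken-rs calls (o200k_base, encode_with_special_tokens) are the same ones used above.

use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};

use tiktoken_rs::CoreBPE;

// Shared encoder, built once on first use (the patch uses tokio's OnceCell instead).
static ENCODER: OnceLock<CoreBPE> = OnceLock::new();

fn encoder() -> &'static CoreBPE {
    // o200k_base() loads the encoding data bundled with tiktoken-rs.
    ENCODER.get_or_init(|| tiktoken_rs::o200k_base().expect("failed to load o200k_base"))
}

// Same counting rule as TokenCounter::count_tokens: encode and take the length.
fn count_tokens(text: &str) -> usize {
    encoder().encode_with_special_tokens(text).len()
}

// Simplified stand-in for the DashMap-based result cache (here keyed by the text itself).
fn count_cached(cache: &Mutex<HashMap<String, usize>>, text: &str) -> usize {
    if let Some(&n) = cache.lock().unwrap().get(text) {
        return n; // cache hit: no re-encoding
    }
    let n = count_tokens(text);
    cache.lock().unwrap().insert(text.to_string(), n);
    n
}

fn main() {
    let cache = Mutex::new(HashMap::new());
    let text = "Hello, how are you?";
    println!("{} -> {} tokens", text, count_cached(&cache, text));
    println!("{} -> {} tokens (served from cache)", text, count_cached(&cache, text));
}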