diff --git a/crates/goose/src/dictation/providers.rs b/crates/goose/src/dictation/providers.rs index 37d864107c8f..d1d88202dcd6 100644 --- a/crates/goose/src/dictation/providers.rs +++ b/crates/goose/src/dictation/providers.rs @@ -1,7 +1,7 @@ use crate::config::Config; use crate::dictation::whisper::LOCAL_WHISPER_MODEL_CONFIG_KEY; use crate::providers::api_client::{ApiClient, AuthMethod}; -use anyhow::{Context, Result}; +use anyhow::Result; use serde::{Deserialize, Serialize}; use std::sync::Mutex; use std::time::Duration; @@ -9,13 +9,10 @@ use utoipa::ToSchema; const REQUEST_TIMEOUT: Duration = Duration::from_secs(30); -// Global lazy-initialized transcriber to reuse the loaded model -// Stores (model_path, transcriber) to detect when model changes static LOCAL_TRANSCRIBER: once_cell::sync::Lazy< Mutex>, > = once_cell::sync::Lazy::new(|| Mutex::new(None)); -// Bundled tokenizer JSON (2.4MB) const WHISPER_TOKENIZER_JSON: &str = include_str!("whisper_data/tokens.json"); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ToSchema)] @@ -85,7 +82,7 @@ pub fn get_provider_def(provider: DictationProvider) -> &'static DictationProvid PROVIDERS .iter() .find(|def| def.provider == provider) - .unwrap() // Safe because all enum variants are in PROVIDERS + .unwrap() } pub fn is_configured(provider: DictationProvider) -> bool { @@ -106,9 +103,7 @@ pub fn is_configured(provider: DictationProvider) -> bool { } pub async fn transcribe_local(audio_bytes: Vec) -> Result { - // Run transcription in a blocking task to avoid blocking the async runtime tokio::task::spawn_blocking(move || { - // Get model ID from config let config = Config::global(); let model_id = config .get(LOCAL_WHISPER_MODEL_CONFIG_KEY, false) @@ -116,17 +111,14 @@ pub async fn transcribe_local(audio_bytes: Vec) -> Result { .and_then(|v| v.as_str().map(|s| s.to_string())) .ok_or_else(|| anyhow::anyhow!("Local Whisper model not configured"))?; - // Convert model ID to full path let model = super::whisper::get_model(&model_id) .ok_or_else(|| anyhow::anyhow!("Unknown model: {}", model_id))?; let model_path = model.local_path(); - // Get or initialize the transcriber let mut transcriber_lock = LOCAL_TRANSCRIBER .lock() .map_err(|e| anyhow::anyhow!("Failed to lock transcriber: {}", e))?; - // Check if we need to load/reload the transcriber let model_path_str = model_path.to_string_lossy().to_string(); let needs_reload = match transcriber_lock.as_ref() { None => true, @@ -145,25 +137,29 @@ pub async fn transcribe_local(audio_bytes: Vec) -> Result { *transcriber_lock = Some((model_path_str, transcriber)); } - // Transcribe the audio let (_, transcriber) = transcriber_lock.as_mut().unwrap(); - let text = transcriber - .transcribe(&audio_bytes) - .context("Transcription failed")?; + let text = transcriber.transcribe(&audio_bytes).map_err(|e| { + tracing::error!("Transcription failed: {}", e); + e + })?; Ok(text) }) .await - .context("Transcription task failed")? + .map_err(|e| { + tracing::error!("Transcription task failed: {}", e); + anyhow::anyhow!(e) + })? } fn build_api_client(provider: DictationProvider) -> Result { let config = Config::global(); let def = get_provider_def(provider); - let api_key = config - .get_secret(def.config_key) - .context(format!("{} not configured", def.config_key))?; + let api_key = config.get_secret(def.config_key).map_err(|e| { + tracing::error!("{} not configured: {}", def.config_key, e); + anyhow::anyhow!("{} not configured", def.config_key) + })?; let base_url = if let Some(host_key) = def.host_key { config @@ -185,7 +181,10 @@ fn build_api_client(provider: DictationProvider) -> Result { DictationProvider::Local => anyhow::bail!("Local provider should not use API client"), }; - ApiClient::with_timeout(base_url, auth, REQUEST_TIMEOUT).context("Failed to create API client") + ApiClient::with_timeout(base_url, auth, REQUEST_TIMEOUT).map_err(|e| { + tracing::error!("Failed to create API client: {}", e); + e + }) } pub async fn transcribe_with_provider( @@ -202,7 +201,10 @@ pub async fn transcribe_with_provider( let part = reqwest::multipart::Part::bytes(audio_bytes) .file_name(format!("audio.{}", extension)) .mime_str(mime_type) - .context("Failed to create multipart")?; + .map_err(|e| { + tracing::error!("Failed to create multipart: {}", e); + anyhow::anyhow!(e) + })?; let form = reqwest::multipart::Form::new() .part("file", part) @@ -212,7 +214,10 @@ pub async fn transcribe_with_provider( .request(None, def.endpoint_path) .multipart_post(form) .await - .context("Request failed")?; + .map_err(|e| { + tracing::error!("Request failed: {}", e); + e + })?; if !response.status().is_success() { let status = response.status(); @@ -229,7 +234,10 @@ pub async fn transcribe_with_provider( } } - let data: serde_json::Value = response.json().await.context("Failed to parse response")?; + let data: serde_json::Value = response.json().await.map_err(|e| { + tracing::error!("Failed to parse response: {}", e); + anyhow::anyhow!(e) + })?; let text = data["text"] .as_str() diff --git a/crates/goose/src/dictation/whisper.rs b/crates/goose/src/dictation/whisper.rs index d32a280bb35d..fe48a44f92f0 100644 --- a/crates/goose/src/dictation/whisper.rs +++ b/crates/goose/src/dictation/whisper.rs @@ -203,15 +203,21 @@ impl WhisperTranscriber { model_path: P, bundled_tokenizer: &str, ) -> Result { + tracing::debug!(model_id, "initializing whisper transcriber"); + let device = if let Ok(device) = Device::new_cuda(0) { + tracing::debug!("using CUDA device"); device } else if let Ok(device) = Device::new_metal(0) { + tracing::debug!("using Metal device"); device } else { + tracing::debug!("using CPU device"); Device::Cpu }; let model_path_ref = model_path.as_ref(); + tracing::debug!(path = %model_path_ref.display(), "loading model from path"); if !model_path_ref.exists() { anyhow::bail!("Model file not found: {}", model_path_ref.display()); @@ -220,6 +226,11 @@ impl WhisperTranscriber { let model = get_model(model_id).ok_or_else(|| anyhow::anyhow!("Unknown model: {}", model_id))?; let config = model.config(); + tracing::debug!( + num_mel_bins = config.num_mel_bins, + d_model = config.d_model, + "loaded model config" + ); let mel_bytes = match config.num_mel_bins { 80 => include_bytes!("whisper_data/melfilters.bytes").as_slice(), @@ -231,14 +242,19 @@ impl WhisperTranscriber { &mut &mel_bytes[..], &mut mel_filters, )?; + tracing::debug!(mel_filters_len = mel_filters.len(), "loaded mel filters"); + tracing::debug!("loading GGUF model weights"); let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( model_path_ref, &device, )?; let model = m::quantized_model::Whisper::load(&vb, config.clone())?; + tracing::debug!("model weights loaded successfully"); + tracing::debug!("loading tokenizer"); let tokenizer = Self::load_tokenizer(model_path_ref, Some(bundled_tokenizer))?; + tracing::debug!("tokenizer loaded successfully"); Ok(Self { model, @@ -280,10 +296,37 @@ impl WhisperTranscriber { } pub fn transcribe(&mut self, audio_data: &[u8]) -> Result { - let mel_tensor = self.prepare_audio_input(audio_data)?; - let (_, _, content_frames) = mel_tensor.dims3()?; + tracing::debug!(audio_bytes = audio_data.len(), "starting transcription"); + + if audio_data.is_empty() { + tracing::debug!("empty audio data received"); + return Ok(String::new()); + } + + let (mel_tensor, actual_content_frames) = self.prepare_audio_input(audio_data)?; + let (_, _, padded_frames) = mel_tensor.dims3()?; + + let content_frames = actual_content_frames.min(padded_frames); + let audio_duration_secs = (content_frames * 160) as f32 / 16000.0; + tracing::debug!( + actual_content_frames, + padded_frames, + content_frames, + audio_duration_secs, + "prepared mel spectrogram" + ); + + if content_frames == 0 { + tracing::debug!("no content frames in mel spectrogram"); + return Ok(String::new()); + } let num_segments = content_frames.div_ceil(N_FRAMES); + tracing::debug!( + num_segments, + n_frames = N_FRAMES, + "processing audio segments" + ); let mut all_text_tokens = Vec::new(); let mut seek = 0; @@ -292,21 +335,63 @@ impl WhisperTranscriber { while seek < content_frames { segment_num += 1; let segment_size = usize::min(content_frames - seek, N_FRAMES); + tracing::debug!(segment_num, segment_size, seek, "processing segment"); let segment_text_tokens = self.process_segment(&mel_tensor, seek, segment_size, segment_num, num_segments)?; + tracing::debug!( + tokens_in_segment = segment_text_tokens.len(), + "segment produced tokens" + ); all_text_tokens.extend(segment_text_tokens); seek += segment_size; } - self.decode_tokens(&all_text_tokens) + tracing::debug!( + total_tokens = all_text_tokens.len(), + "decoding tokens to text" + ); + + if all_text_tokens.is_empty() { + tracing::warn!( + audio_bytes = audio_data.len(), + audio_duration_secs, + num_segments, + "no tokens produced from audio - possible silence or unrecognized speech" + ); + return Ok(String::new()); + } + + let raw_result = self.decode_tokens(&all_text_tokens)?; + let result = deduplicate_text(&raw_result); + if result != raw_result { + tracing::debug!( + before_len = raw_result.len(), + after_len = result.len(), + "text-level deduplication removed repeated phrases" + ); + } + tracing::debug!(result_len = result.len(), "transcription complete"); + Ok(result) } - fn prepare_audio_input(&self, audio_data: &[u8]) -> Result { + fn prepare_audio_input(&self, audio_data: &[u8]) -> Result<(Tensor, usize)> { + tracing::debug!(audio_bytes = audio_data.len(), "decoding audio to PCM"); let pcm_data = decode_audio_simple(audio_data)?; + let pcm_samples = pcm_data.len(); + tracing::debug!(pcm_samples, "converting PCM to mel spectrogram"); + + let actual_content_frames = pcm_samples / 160; + let mel = audio::pcm_to_mel(&self.config, &pcm_data, &self.mel_filters); let mel_len = mel.len(); + tracing::debug!( + mel_len, + num_mel_bins = self.config.num_mel_bins, + actual_content_frames, + "creating mel tensor" + ); let mel_tensor = Tensor::from_vec( mel, ( @@ -317,7 +402,7 @@ impl WhisperTranscriber { &self.device, )?; - Ok(mel_tensor) + Ok((mel_tensor, actual_content_frames)) } fn process_segment( @@ -331,8 +416,25 @@ impl WhisperTranscriber { let _time_offset = (seek * 160) as f32 / 16000.0; // HOP_LENGTH = 160 let _segment_duration = (segment_size * 160) as f32 / 16000.0; let mel_segment = mel_tensor.narrow(2, seek, segment_size)?; + + if tracing::enabled!(tracing::Level::DEBUG) { + let mel_flat = mel_segment.flatten_all()?; + let mel_mean: f32 = mel_flat.mean(0)?.to_scalar()?; + let mel_max: f32 = mel_flat.max(0)?.to_scalar()?; + let mel_min: f32 = mel_flat.min(0)?.to_scalar()?; + tracing::debug!(mel_mean, mel_max, mel_min, "mel segment statistics"); + } + self.model.decoder.reset_kv_cache(); let audio_features = self.model.encoder.forward(&mel_segment, true)?; + + if tracing::enabled!(tracing::Level::DEBUG) { + let af_flat = audio_features.flatten_all()?; + let af_mean: f32 = af_flat.mean(0)?.to_scalar()?; + let af_max: f32 = af_flat.max(0)?.to_scalar()?; + let af_min: f32 = af_flat.min(0)?.to_scalar()?; + tracing::debug!(af_mean, af_max, af_min, "audio features statistics"); + } let suppress_tokens = { let mut suppress = vec![0f32; self.config.vocab_size]; for &token_id in &self.config.suppress_tokens { @@ -374,15 +476,44 @@ impl WhisperTranscriber { tokens.push(next_token); - if next_token == EOT_TOKEN || tokens.len() > self.config.max_target_positions { + if next_token == EOT_TOKEN { + tracing::debug!(tokens_generated = tokens.len() - 3, "EOT token received"); + break; + } + if tokens.len() > self.config.max_target_positions { + tracing::debug!("max target positions reached"); + break; + } + + if let Some(truncate_at) = self.detect_repetition(&tokens) { + tracing::debug!( + truncate_at, + tokens_before = tokens.len(), + "repetition detected, truncating" + ); + tokens.truncate(truncate_at); break; } } + + tracing::debug!( + all_tokens = ?&tokens[3..], + "all tokens generated in segment" + ); + let segment_text_tokens: Vec = tokens[3..] .iter() .filter(|&&t| t != EOT_TOKEN && t < TIMESTAMP_BEGIN) .copied() .collect(); + + if segment_text_tokens.is_empty() && tokens.len() > 3 { + tracing::debug!( + raw_tokens = ?&tokens[3..], + "no text tokens found after filtering (all were EOT or timestamps)" + ); + } + Ok(segment_text_tokens) } @@ -445,7 +576,7 @@ impl WhisperTranscriber { let penultimate_was_timestamp = if sampled_tokens.len() >= 2 { sampled_tokens[sampled_tokens.len() - 2] >= TIMESTAMP_BEGIN } else { - false + true }; if last_was_timestamp { @@ -477,11 +608,7 @@ impl WhisperTranscriber { .collect(); if !timestamp_tokens.is_empty() { - let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp { - *timestamp_tokens.last().unwrap() - } else { - timestamp_tokens.last().unwrap() + 1 - }; + let timestamp_last = timestamp_tokens.last().unwrap() + 1; for i in 0..vocab_size { mask_buffer[i as usize] = if i >= TIMESTAMP_BEGIN && i < timestamp_last { @@ -560,6 +687,12 @@ impl WhisperTranscriber { let max_text_token_logprob: f32 = text_log_probs.max(0)?.to_scalar::()?; + tracing::debug!( + timestamp_logprob, + max_text_token_logprob, + "timestamp vs text probability comparison" + ); + if timestamp_logprob > max_text_token_logprob { for i in 0..vocab_size { mask_buffer[i as usize] = if i < TIMESTAMP_BEGIN { @@ -580,9 +713,178 @@ impl WhisperTranscriber { .decode(tokens, true) .map_err(|e| anyhow::anyhow!("Failed to decode tokens: {}", e)) } + + fn detect_repetition(&self, tokens: &[u32]) -> Option { + detect_repetition_impl(tokens, SAMPLE_BEGIN, TIMESTAMP_BEGIN) + } +} + +/// Remove repeated phrases from transcribed text. +/// +/// Whisper models (especially smaller/quantized ones) tend to loop, producing output like +/// "I could build a record mode. I could build a record mode. I could build a record mode." +/// This function collapses adjacent duplicate sentences/phrases down to a single occurrence. +fn deduplicate_text(text: &str) -> String { + let trimmed = text.trim(); + if trimmed.is_empty() { + return String::new(); + } + + // Split into sentences on common boundaries (. ! ?) + let sentences = split_into_sentences(trimmed); + if sentences.len() <= 1 { + return trimmed.to_string(); + } + + let mut result: Vec<&str> = Vec::new(); + + let mut i = 0; + while i < sentences.len() { + // Try to find a repeating pattern starting at position i. + // Check pattern lengths from 1 sentence up to half the remaining sentences. + let remaining = sentences.len() - i; + let max_pattern_len = remaining / 2; + let mut best_pattern_len = 0; + let mut best_repeat_count = 0; + let mut best_total_consumed = 0; + + for pattern_len in 1..=max_pattern_len { + let pattern = &sentences[i..i + pattern_len]; + let mut count = 1; + let mut pos = i + pattern_len; + while pos + pattern_len <= sentences.len() { + let candidate = &sentences[pos..pos + pattern_len]; + if pattern + .iter() + .zip(candidate.iter()) + .all(|(a, b)| a.trim() == b.trim()) + { + count += 1; + pos += pattern_len; + } else { + break; + } + } + // Prefer the pattern that removes the most repeated sentences + let total_consumed = pattern_len * count; + if count >= 2 && total_consumed > best_total_consumed { + best_pattern_len = pattern_len; + best_repeat_count = count; + best_total_consumed = total_consumed; + } + } + + if best_repeat_count >= 2 { + // Keep only the first occurrence of the repeated pattern + for j in 0..best_pattern_len { + result.push(sentences[i + j]); + } + i += best_pattern_len * best_repeat_count; + } else { + result.push(sentences[i]); + i += 1; + } + } + + result.join("").trim_end().to_string() +} + +#[allow(clippy::string_slice)] // Splitting on ASCII punctuation; indices are always valid UTF-8 boundaries +fn split_into_sentences(text: &str) -> Vec<&str> { + let mut sentences = Vec::new(); + let mut last = 0; + let bytes = text.as_bytes(); + + for (i, &b) in bytes.iter().enumerate() { + if b == b'.' || b == b'!' || b == b'?' { + // Include trailing whitespace with the sentence + let mut end = i + 1; + while end < bytes.len() && bytes[end] == b' ' { + end += 1; + } + sentences.push(&text[last..end]); + last = end; + } + } + + // Don't forget the trailing fragment (if any) + if last < text.len() { + sentences.push(&text[last..]); + } + + sentences +} + +/// Detect repetition in token sequence, returning the index to truncate to if repetition found. +/// Filters out timestamp tokens (>= timestamp_begin) when looking for patterns. +/// Returns Some(truncate_index) if repetition detected, None otherwise. +fn detect_repetition_impl( + tokens: &[u32], + sample_begin: usize, + timestamp_begin: u32, +) -> Option { + if tokens.len() <= sample_begin { + return None; + } + + // Filter out timestamp tokens to get just text tokens, but remember original positions + let text_tokens: Vec<(usize, u32)> = tokens[sample_begin..] + .iter() + .enumerate() + .filter(|(_, &t)| t < timestamp_begin) + .map(|(i, &t)| (i + sample_begin, t)) + .collect(); + + // Need at least 3 tokens to detect any repetition (e.g., [A, A, A]) + if text_tokens.len() < 3 { + return None; + } + + let n = text_tokens.len(); + + // Try different pattern lengths, starting from 1 + for pattern_len in 1..=(n / 2) { + // Check if the last `pattern_len` tokens match the preceding `pattern_len` tokens + let pattern_start = n - pattern_len; + let prev_pattern_start = n - 2 * pattern_len; + + let matches = (0..pattern_len) + .all(|i| text_tokens[prev_pattern_start + i].1 == text_tokens[pattern_start + i].1); + + if !matches { + continue; + } + + // Found adjacent repeated pattern - count total repetitions + let mut repeat_count = 2; + let mut check_start = prev_pattern_start; + + while check_start >= pattern_len { + let earlier_start = check_start - pattern_len; + let still_matches = (0..pattern_len) + .all(|i| text_tokens[earlier_start + i].1 == text_tokens[pattern_start + i].1); + if still_matches { + repeat_count += 1; + check_start = earlier_start; + } else { + break; + } + } + + // Trigger on: 3+ repeats of anything, or 2 repeats of 5+ token patterns + if repeat_count >= 3 || (repeat_count >= 2 && pattern_len >= 5) { + // Return the original token position after the first pattern + let first_pattern_end_text_idx = check_start + pattern_len; + let truncate_pos = text_tokens[first_pattern_end_text_idx].0; + return Some(truncate_pos); + } + } + + None } fn decode_audio_simple(audio_data: &[u8]) -> Result> { + tracing::debug!(input_bytes = audio_data.len(), "decoding audio"); let audio_vec = audio_data.to_vec(); let cursor = Cursor::new(audio_vec); let mss = MediaSourceStream::new(Box::new(cursor), Default::default()); @@ -621,11 +923,14 @@ fn decode_audio_simple(audio_data: &[u8]) -> Result> { anyhow::bail!("No channel information in audio track (neither channels nor channel_layout)") }; + tracing::debug!(sample_rate, channels, "audio format detected"); + let mut decoder = symphonia::default::get_codecs() .make(&track.codec_params, &DecoderOptions::default()) .context("Failed to create audio decoder - please ensure browser sends WAV format audio")?; let mut pcm_data = Vec::new(); + let mut packet_count = 0; loop { let packet = match format.next_packet() { @@ -641,6 +946,7 @@ fn decode_audio_simple(audio_data: &[u8]) -> Result> { match decoder.decode(&packet) { Ok(decoded) => { pcm_data.extend(audio_buffer_to_f32(&decoded)); + packet_count += 1; } Err(symphonia::core::errors::Error::DecodeError(_)) => { continue; @@ -649,18 +955,44 @@ fn decode_audio_simple(audio_data: &[u8]) -> Result> { } } + tracing::debug!( + packet_count, + pcm_samples = pcm_data.len(), + "decoded audio packets" + ); + let mono_data = if channels > 1 { + tracing::debug!(channels, "converting to mono"); convert_to_mono(&pcm_data, channels) } else { pcm_data }; let resampled = if sample_rate != 16000 { + tracing::debug!(from_rate = sample_rate, to_rate = 16000, "resampling audio"); resample_audio(&mono_data, sample_rate, 16000)? } else { mono_data }; + if tracing::enabled!(tracing::Level::DEBUG) { + if !resampled.is_empty() { + let max_abs = resampled.iter().map(|s| s.abs()).fold(0.0f32, f32::max); + let mean_abs = resampled.iter().map(|s| s.abs()).sum::() / resampled.len() as f32; + let rms = + (resampled.iter().map(|s| s * s).sum::() / resampled.len() as f32).sqrt(); + tracing::debug!( + output_samples = resampled.len(), + max_abs, + mean_abs, + rms, + "audio decoding complete with PCM stats" + ); + } else { + tracing::debug!(output_samples = 0, "audio decoding complete (empty)"); + } + } + Ok(resampled) } @@ -734,6 +1066,13 @@ fn resample_audio(data: &[f32], from_rate: u32, to_rate: u32) -> Result return Ok(data.to_vec()); } + tracing::debug!( + from_rate, + to_rate, + input_samples = data.len(), + "resampling audio" + ); + let params = SincInterpolationParameters { sinc_len: 256, f_cutoff: 0.95, @@ -753,5 +1092,110 @@ fn resample_audio(data: &[f32], from_rate: u32, to_rate: u32) -> Result let waves_in = vec![data.to_vec()]; let waves_out = resampler.process(&waves_in, None)?; + tracing::debug!(output_samples = waves_out[0].len(), "resampling complete"); Ok(waves_out[0].clone()) } + +#[cfg(test)] +mod tests { + use super::*; + + use test_case::test_case; + + const TS: u32 = 50364; // A timestamp token for tests + + // detect_repetition_impl tests + // sample_begin=3 means tokens[0..3] are SOT, language, transcribe + // timestamp_begin=50364 means tokens >= 50364 are timestamps + + #[test_case(&[0, 1, 2, 10, 10, 10], Some(4) ; "single token repeated 3x")] + #[test_case(&[0, 1, 2, 10, 10], None ; "single token repeated 2x not enough")] + #[test_case(&[0, 1, 2, 10, 20, 30, 10, 20, 30], None ; "3-token pattern repeated 2x not enough")] + #[test_case(&[0, 1, 2, 10, 20, 30, 40, 50, 10, 20, 30, 40, 50], Some(8) ; "5-token pattern repeated 2x")] + #[test_case(&[0, 1, 2, 10, 20, 10, 20, 10, 20], Some(5) ; "2-token pattern repeated 3x")] + #[test_case(&[0, 1, 2, 10, 20, 30, 40, 10, 20, 30, 40], None ; "4-token pattern repeated 2x not enough")] + #[test_case(&[0, 1, 2, 10, 99, 20, 10, 99, 20], None ; "non-adjacent same tokens no trigger")] + fn test_detect_repetition_no_timestamps(tokens: &[u32], expected: Option) { + assert_eq!(detect_repetition_impl(tokens, 3, 50364), expected); + } + + #[test_case( + &[0, 1, 2, TS, 10, 20, 30, TS+1, TS+2, 10, 20, 30, TS+3], + None ; + "phrase 3 tokens with timestamps 2x not enough" + )] + #[test_case( + &[0, 1, 2, TS, 10, 10, 10, TS+1], + Some(5) ; + "single token 3x with surrounding timestamps" + )] + #[test_case( + &[0, 1, 2, TS, 10, 20, TS+1, TS+2, 10, 20, TS+3, TS+4, 10, 20, TS+5], + Some(8) ; + "2-token pattern 3x with timestamps interleaved" + )] + fn test_detect_repetition_with_timestamps(tokens: &[u32], expected: Option) { + assert_eq!(detect_repetition_impl(tokens, 3, 50364), expected); + } + + // Real example from logs: phrase repeated 3x with timestamps + #[test] + fn test_detect_repetition_real_example() { + let tokens: Vec = vec![ + 0, 1, 2, // SOT, lang, transcribe (indices 0-2) + 50364, 286, 500, 380, 458, 983, 309, 311, 18617, 2564, 13, // first phrase + ts + 50450, 50475, 286, 500, 380, 458, 983, 309, 311, 18617, 2564, + 13, // second phrase + ts + 50550, 50551, 286, 500, 380, 458, 983, 309, 311, 18617, 2564, + 13, // third phrase + ts + ]; + // Text tokens are: 286, 500, 380, 458, 983, 309, 311, 18617, 2564, 13 (10 tokens) + // Repeated 3 times, should trigger + let result = detect_repetition_impl(&tokens, 3, 50364); + assert!(result.is_some(), "Should detect repetition in real example"); + } + + #[test] + fn test_no_false_positive_on_dog_sentences() { + // "I saw a dog. I liked the dog. I gave the dog food." + // dog=100, other tokens are different + let tokens: Vec = vec![ + 0, 1, 2, 10, 11, 12, 100, 13, // I saw a dog. + 20, 21, 22, 100, 23, // I liked the dog. + 30, 31, 32, 100, 33, // I gave the dog food. + ]; + assert_eq!(detect_repetition_impl(&tokens, 3, 50364), None); + } + + // deduplicate_text tests + #[test_case("", "" ; "empty")] + #[test_case(" ", "" ; "whitespace")] + #[test_case("I went to the store. Then I came home.", "I went to the store. Then I came home." ; "no repetition")] + #[test_case( + "I could build a record mode. I could build a record mode. I could build a record mode.", + "I could build a record mode." ; + "single sentence 3x" + )] + #[test_case( + "Yeah I was thinking about that. Yeah I was thinking about that.", + "Yeah I was thinking about that." ; + "single sentence 2x" + )] + #[test_case( + "Who works for Flux? Who works for Flux? Who works for Flux?", + "Who works for Flux?" ; + "question marks" + )] + #[test_case("Stop! Stop! Stop!", "Stop!" ; "exclamation marks")] + #[test_case("hello hello hello hello", "hello hello hello hello" ; "no sentence boundaries")] + fn test_deduplicate_text(input: &str, expected: &str) { + assert_eq!(deduplicate_text(input), expected); + } + + #[test_case("Hello. World. Foo.", vec!["Hello. ", "World. ", "Foo."] ; "basic")] + #[test_case("Hello. World", vec!["Hello. ", "World"] ; "trailing fragment")] + #[test_case("Really? Yes! Ok.", vec!["Really? ", "Yes! ", "Ok."] ; "mixed punctuation")] + fn test_split_into_sentences(input: &str, expected: Vec<&str>) { + assert_eq!(split_into_sentences(input), expected); + } +}