-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Fix speech local #7181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix speech local #7181
Changes from all commits
4cd8ab4
05203c1
c5753a2
6820e08
efbbc87
0f14388
4e63d65
7b90808
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,21 +1,18 @@ | ||
| use crate::config::Config; | ||
| use crate::dictation::whisper::LOCAL_WHISPER_MODEL_CONFIG_KEY; | ||
| use crate::providers::api_client::{ApiClient, AuthMethod}; | ||
| use anyhow::{Context, Result}; | ||
| use anyhow::Result; | ||
| use serde::{Deserialize, Serialize}; | ||
| use std::sync::Mutex; | ||
| use std::time::Duration; | ||
| use utoipa::ToSchema; | ||
|
|
||
| const REQUEST_TIMEOUT: Duration = Duration::from_secs(30); | ||
|
|
||
| // Global lazy-initialized transcriber to reuse the loaded model | ||
| // Stores (model_path, transcriber) to detect when model changes | ||
| static LOCAL_TRANSCRIBER: once_cell::sync::Lazy< | ||
| Mutex<Option<(String, super::whisper::WhisperTranscriber)>>, | ||
| > = once_cell::sync::Lazy::new(|| Mutex::new(None)); | ||
|
|
||
| // Bundled tokenizer JSON (2.4MB) | ||
| const WHISPER_TOKENIZER_JSON: &str = include_str!("whisper_data/tokens.json"); | ||
|
|
||
| #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ToSchema)] | ||
|
|
@@ -85,7 +82,7 @@ pub fn get_provider_def(provider: DictationProvider) -> &'static DictationProvid | |
| PROVIDERS | ||
| .iter() | ||
| .find(|def| def.provider == provider) | ||
| .unwrap() // Safe because all enum variants are in PROVIDERS | ||
| .unwrap() | ||
| } | ||
|
|
||
| pub fn is_configured(provider: DictationProvider) -> bool { | ||
|
|
@@ -106,27 +103,22 @@ pub fn is_configured(provider: DictationProvider) -> bool { | |
| } | ||
|
|
||
| pub async fn transcribe_local(audio_bytes: Vec<u8>) -> Result<String> { | ||
| // Run transcription in a blocking task to avoid blocking the async runtime | ||
| tokio::task::spawn_blocking(move || { | ||
| // Get model ID from config | ||
| let config = Config::global(); | ||
| let model_id = config | ||
| .get(LOCAL_WHISPER_MODEL_CONFIG_KEY, false) | ||
| .ok() | ||
| .and_then(|v| v.as_str().map(|s| s.to_string())) | ||
| .ok_or_else(|| anyhow::anyhow!("Local Whisper model not configured"))?; | ||
|
|
||
| // Convert model ID to full path | ||
| let model = super::whisper::get_model(&model_id) | ||
| .ok_or_else(|| anyhow::anyhow!("Unknown model: {}", model_id))?; | ||
| let model_path = model.local_path(); | ||
|
|
||
| // Get or initialize the transcriber | ||
| let mut transcriber_lock = LOCAL_TRANSCRIBER | ||
| .lock() | ||
| .map_err(|e| anyhow::anyhow!("Failed to lock transcriber: {}", e))?; | ||
|
|
||
| // Check if we need to load/reload the transcriber | ||
| let model_path_str = model_path.to_string_lossy().to_string(); | ||
| let needs_reload = match transcriber_lock.as_ref() { | ||
| None => true, | ||
|
|
@@ -145,25 +137,29 @@ pub async fn transcribe_local(audio_bytes: Vec<u8>) -> Result<String> { | |
| *transcriber_lock = Some((model_path_str, transcriber)); | ||
| } | ||
|
|
||
| // Transcribe the audio | ||
| let (_, transcriber) = transcriber_lock.as_mut().unwrap(); | ||
| let text = transcriber | ||
| .transcribe(&audio_bytes) | ||
| .context("Transcription failed")?; | ||
| let text = transcriber.transcribe(&audio_bytes).map_err(|e| { | ||
| tracing::error!("Transcription failed: {}", e); | ||
| e | ||
| })?; | ||
|
|
||
| Ok(text) | ||
| }) | ||
| .await | ||
| .context("Transcription task failed")? | ||
| .map_err(|e| { | ||
| tracing::error!("Transcription task failed: {}", e); | ||
| anyhow::anyhow!(e) | ||
| })? | ||
| } | ||
|
|
||
| fn build_api_client(provider: DictationProvider) -> Result<ApiClient> { | ||
| let config = Config::global(); | ||
| let def = get_provider_def(provider); | ||
|
|
||
| let api_key = config | ||
| .get_secret(def.config_key) | ||
| .context(format!("{} not configured", def.config_key))?; | ||
| let api_key = config.get_secret(def.config_key).map_err(|e| { | ||
| tracing::error!("{} not configured: {}", def.config_key, e); | ||
| anyhow::anyhow!("{} not configured", def.config_key) | ||
| })?; | ||
|
Comment on lines
+159
to
+162
|
||
|
|
||
| let base_url = if let Some(host_key) = def.host_key { | ||
| config | ||
|
|
@@ -185,7 +181,10 @@ fn build_api_client(provider: DictationProvider) -> Result<ApiClient> { | |
| DictationProvider::Local => anyhow::bail!("Local provider should not use API client"), | ||
| }; | ||
|
|
||
| ApiClient::with_timeout(base_url, auth, REQUEST_TIMEOUT).context("Failed to create API client") | ||
| ApiClient::with_timeout(base_url, auth, REQUEST_TIMEOUT).map_err(|e| { | ||
| tracing::error!("Failed to create API client: {}", e); | ||
| e | ||
| }) | ||
| } | ||
|
|
||
| pub async fn transcribe_with_provider( | ||
|
|
@@ -202,7 +201,10 @@ pub async fn transcribe_with_provider( | |
| let part = reqwest::multipart::Part::bytes(audio_bytes) | ||
| .file_name(format!("audio.{}", extension)) | ||
| .mime_str(mime_type) | ||
| .context("Failed to create multipart")?; | ||
| .map_err(|e| { | ||
| tracing::error!("Failed to create multipart: {}", e); | ||
| anyhow::anyhow!(e) | ||
| })?; | ||
|
Comment on lines
201
to
+207
|
||
|
|
||
| let form = reqwest::multipart::Form::new() | ||
| .part("file", part) | ||
|
|
@@ -212,7 +214,10 @@ pub async fn transcribe_with_provider( | |
| .request(None, def.endpoint_path) | ||
| .multipart_post(form) | ||
| .await | ||
| .context("Request failed")?; | ||
| .map_err(|e| { | ||
| tracing::error!("Request failed: {}", e); | ||
| e | ||
| })?; | ||
|
|
||
| if !response.status().is_success() { | ||
| let status = response.status(); | ||
|
|
@@ -229,7 +234,10 @@ pub async fn transcribe_with_provider( | |
| } | ||
| } | ||
|
|
||
| let data: serde_json::Value = response.json().await.context("Failed to parse response")?; | ||
| let data: serde_json::Value = response.json().await.map_err(|e| { | ||
| tracing::error!("Failed to parse response: {}", e); | ||
| anyhow::anyhow!(e) | ||
| })?; | ||
|
Comment on lines
+237
to
+240
|
||
|
|
||
| let text = data["text"] | ||
| .as_str() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replacing
anyhow::Contextwithmap_err+ logging drops useful error context for callers (and can lead to duplicate logging up the stack); prefer keeping.context("…")on these fallible calls and let the top-level handler decide if/where to log.