616 changes: 597 additions & 19 deletions src-tauri/Cargo.lock

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions src-tauri/Cargo.toml
@@ -31,6 +31,20 @@ reqwest = { version = "0.12", features = ["stream", "json"] }
futures-util = "0.3"
directories = "5.0"
symphonia = { version = "0.5", features = ["mp3", "wav", "flac"] }
image = "0.24"
ndarray = "0.16.1"
tokenizers = { version = "0.22.1" }
thiserror = "1.0"

# ONNX Runtime with platform-specific execution providers
[target.'cfg(target_os = "windows")'.dependencies]
ort = { version = "=2.0.0-rc.10", features = ["directml", "download-binaries"] }

[target.'cfg(target_os = "linux")'.dependencies]
ort = { version = "=2.0.0-rc.10", features = ["cuda", "download-binaries"] }

[target.'cfg(target_os = "macos")'.dependencies]
ort = { version = "=2.0.0-rc.10", features = ["coreml", "download-binaries"] }

[dev-dependencies]
image = "0.25"
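
The three cfg-gated `ort` dependencies above pick a hardware execution provider per platform (DirectML on Windows, CUDA on Linux, CoreML on macOS), with `download-binaries` pulling prebuilt runtime libraries. The sketch below is illustrative only and not part of this diff; it shows roughly how those compiled-in providers might be registered when building a session, and the exact module paths and builder methods of the `ort` 2.0 release candidates are assumptions that may differ between rc versions.

use ort::execution_providers::{
    CUDAExecutionProvider, CoreMLExecutionProvider, DirectMLExecutionProvider,
};
use ort::session::Session;

fn build_session(model_path: &std::path::Path) -> Result<Session, ort::Error> {
    let builder = Session::builder()?;
    // Register the provider compiled in for this platform; ort typically falls
    // back to the CPU provider if the accelerator cannot be initialized.
    #[cfg(target_os = "windows")]
    let builder = builder.with_execution_providers([DirectMLExecutionProvider::default().build()])?;
    #[cfg(target_os = "linux")]
    let builder = builder.with_execution_providers([CUDAExecutionProvider::default().build()])?;
    #[cfg(target_os = "macos")]
    let builder = builder.with_execution_providers([CoreMLExecutionProvider::default().build()])?;
    builder.commit_from_file(model_path)
}
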
3 changes: 2 additions & 1 deletion src-tauri/src/inference_engine.rs
@@ -255,7 +255,8 @@ extern "C" {
fn mtmd_bitmap_init_from_audio(n_samples: usize, data: *const c_float) -> *mut MtmdBitmap;
fn mtmd_bitmap_is_audio(bitmap: *const MtmdBitmap) -> bool;
fn mtmd_bitmap_free(bitmap: *mut MtmdBitmap);
fn mtmd_support_audio(ctx: *mut MtmdContext) -> bool;
fn mtmd_support_vision(ctx: *const MtmdContext) -> bool;
fn mtmd_support_audio(ctx: *const MtmdContext) -> bool;
fn mtmd_get_audio_bitrate(ctx: *mut MtmdContext) -> c_int;

fn mtmd_input_chunks_init() -> *mut MtmdInputChunks;
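
This change adds a vision-support probe and switches both probes to take `*const MtmdContext`, reflecting that they only read from the context. An illustrative safe wrapper around these declarations (not part of this diff) might look like:

// Assumes `ctx` came from the engine's own init path and is still alive.
fn multimodal_support(ctx: *const MtmdContext) -> (bool, bool) {
    if ctx.is_null() {
        return (false, false);
    }
    // SAFETY: `ctx` is non-null and points to a live MtmdContext.
    unsafe { (mtmd_support_vision(ctx), mtmd_support_audio(ctx)) }
}
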
121 changes: 120 additions & 1 deletion src-tauri/src/lib.rs
@@ -10,14 +10,19 @@ mod llama_inference;
pub mod model_manager;
pub mod inference_engine;
mod audio_decoder;
mod vlm_onnx;

use model_manager::{ModelManager, DownloadProgress};
use inference_engine::{InferenceEngine, SharedInferenceEngine, create_shared_engine};
use audio_decoder::decode_audio_file;
use vlm_onnx::VlmOnnx;

// Download cancellation state
pub type DownloadCancellation = Arc<AtomicBool>;

// Shared ONNX engine
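// Wrapped in Option so the engine can be taken out of the mutex, moved into a
// blocking task for inference, and put back afterwards (see generate_onnx_response).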
pub type SharedOnnxEngine = Arc<tokio::sync::Mutex<Option<VlmOnnx>>>;

// Chat message for conversation history
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
@@ -285,12 +290,122 @@ async fn generate_response_audio(
response.map_err(|e| e.to_string())
}

// Download ONNX model from HuggingFace
#[tauri::command]
async fn download_onnx_model(
repo: String,
quantization: String,
app: tauri::AppHandle,
cancellation: State<'_, DownloadCancellation>,
) -> Result<String, String> {
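// Clear any stale cancellation request from a previous download before handing
// a clone of the flag to the downloader.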
cancellation.store(false, Ordering::SeqCst);

let manager = ModelManager::new().map_err(|e| e.to_string())?;
let cancel_flag = cancellation.inner().clone();

let model_id = manager
.download_smolvlm_onnx(
&repo,
&quantization,
move |progress: DownloadProgress| {
let _ = app.emit("download-progress", &progress);
},
cancel_flag,
)
.await
.map_err(|e| e.to_string())?;

Ok(model_id)
}

// Load ONNX model
#[tauri::command]
async fn load_onnx_model(
model_id: String,
onnx_engine: State<'_, SharedOnnxEngine>,
) -> Result<(), String> {
println!("Loading ONNX model: {}", model_id);

let manager = ModelManager::new().map_err(|e| e.to_string())?;
let (vision_path, embed_path, decoder_path, tokenizer_path) = manager
.get_onnx_model_paths(&model_id)
.await
.map_err(|e| e.to_string())?;

println!("Vision: {:?}", vision_path);
println!("Embed: {:?}", embed_path);
println!("Decoder: {:?}", decoder_path);
println!("Tokenizer: {:?}", tokenizer_path);

// Load the ONNX model in a blocking task
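// so that ONNX session creation does not stall the async runtime.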
let engine = tokio::task::spawn_blocking(move || {
VlmOnnx::new(&vision_path, &embed_path, &decoder_path, &tokenizer_path)
})
.await
.map_err(|e| format!("Task join error: {}", e))?
.map_err(|e| e.to_string())?;

// Store the engine
let mut engine_lock = onnx_engine.lock().await;
*engine_lock = Some(engine);

println!("ONNX model loaded successfully");
Ok(())
}

// Generate response with ONNX model
#[tauri::command]
async fn generate_onnx_response(
prompt: String,
image_data: Vec<u8>,
image_width: u32,
image_height: u32,
onnx_engine: State<'_, SharedOnnxEngine>,
) -> Result<String, String> {
println!("Generating ONNX response");

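// Take the engine out of the shared state so the mutex is not held while
// inference runs on a blocking thread; it is put back once generation finishes.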
let mut engine_lock = onnx_engine.lock().await;
let mut engine_opt = engine_lock.take();

if engine_opt.is_none() {
return Err("ONNX model not loaded".to_string());
}

drop(engine_lock);

// Run inference in a blocking task
let (response, engine_instance) = tokio::task::spawn_blocking(move || {
let mut eng = engine_opt.take().unwrap();
let result = eng.generate(&prompt, &image_data, image_width, image_height);
(result, eng)
})
.await
.map_err(|e| format!("Task join error: {}", e))?;

// Put the engine back
let mut engine_lock = onnx_engine.lock().await;
*engine_lock = Some(engine_instance);

response.map_err(|e| e.to_string())
}

// Check if ONNX model is downloaded
#[tauri::command]
async fn is_onnx_model_downloaded(model_id: String) -> Result<bool, String> {
let manager = ModelManager::new().map_err(|e| e.to_string())?;
manager
.is_onnx_model_downloaded(&model_id)
.await
.map_err(|e| e.to_string())
}

#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
tauri::Builder::default()
.plugin(tauri_plugin_opener::init())
.plugin(tauri_plugin_dialog::init())
.manage(create_shared_engine())
.manage(Arc::new(tokio::sync::Mutex::new(None::<VlmOnnx>))) // ONNX engine
.manage(Arc::new(AtomicBool::new(false))) // Download cancellation flag
.invoke_handler(tauri::generate_handler![
greet,
@@ -304,7 +419,11 @@ pub fn run() {
load_model,
generate_response,
check_audio_support,
generate_response_audio
generate_response_audio,
download_onnx_model,
load_onnx_model,
generate_onnx_response,
is_onnx_model_downloaded
])
.setup(|app| {
// Create menu