diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index f1edefc..e43e308 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -8,6 +8,20 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", + "once_cell", + "serde", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -313,7 +327,7 @@ dependencies = [ "anyhow", "arrayvec", "log", - "nom", + "nom 8.0.0", "num-rational", "v_frame", ] @@ -327,6 +341,12 @@ dependencies = [ "arrayvec", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.7" @@ -339,6 +359,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "baseweightcanvas" version = "0.1.0" @@ -346,9 +372,12 @@ dependencies = [ "anyhow", "directories", "futures-util", - "image", + "image 0.24.9", + "image 0.25.9", "libloading 0.8.9", + "ndarray", "once_cell", + "ort", "reqwest", "serde", "serde_json", @@ -358,6 +387,8 @@ dependencies = [ "tauri-plugin-dialog", "tauri-plugin-opener", "tempfile", + "thiserror 1.0.69", + "tokenizers", "tokio", ] @@ -558,6 +589,15 @@ dependencies = [ "toml 0.9.8", ] +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + [[package]] name = "cc" version = "1.2.46" @@ -593,7 +633,7 @@ version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02" dependencies = [ - "smallvec", + "smallvec 1.15.1", "target-lexicon", ] @@ -637,6 +677,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -646,6 +701,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.61.2", +] + [[package]] name = "convert_case" version = "0.4.0" @@ -802,7 +870,7 @@ dependencies = [ "phf 0.10.1", "proc-macro2", "quote", - "smallvec", + "smallvec 1.15.1", "syn 1.0.109", ] @@ -826,14 +894,38 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "darling" +version = "0.20.11" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + [[package]] name = "darling" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.21.3", + "darling_macro 0.21.3", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.110", ] [[package]] @@ -850,17 +942,47 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core 0.20.11", + "quote", + "syn 2.0.110", +] + [[package]] name = "darling_macro" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ - "darling_core", + "darling_core 0.21.3", "quote", "syn 2.0.110", ] +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.5.5" @@ -871,6 +993,37 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.110", +] + [[package]] name = "derive_more" version = "0.99.20" @@ -1065,6 +1218,12 @@ version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1148,6 +1307,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + [[package]] name = "event-listener" version = "5.4.1" @@ -1180,7 +1348,7 @@ dependencies = [ "lebe", "miniz_oxide", "rayon-core", - "smallvec", + "smallvec 1.15.1", "zune-inflate", ] @@ -1235,6 +1403,18 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "filetime" +version = "0.2.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.60.2", +] + [[package]] name = "find-msvc-tools" version = "0.1.5" @@ -1555,6 +1735,16 @@ dependencies = [ "wasip2", ] +[[package]] +name = "gif" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae047235e33e2829703574b54fdec96bfbad892062d97fed2f76022287de61b" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "gif" version = "0.14.0" @@ -1580,7 +1770,7 @@ dependencies = [ "libc", "once_cell", "pin-project-lite", - "smallvec", + "smallvec 1.15.1", "thiserror 1.0.69", ] @@ -1616,7 +1806,7 @@ dependencies = [ "libc", "memchr", "once_cell", - "smallvec", + "smallvec 1.15.1", "thiserror 1.0.69", ] @@ -1854,7 +2044,7 @@ dependencies = [ "itoa", "pin-project-lite", "pin-utils", - "smallvec", + "smallvec 1.15.1", "tokio", "want", ] @@ -1987,7 +2177,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.15.1", "zerovec", ] @@ -2045,7 +2235,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.15.1", "utf8_iter", ] @@ -2059,6 +2249,24 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "image" +version = "0.24.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "exr", + "gif 0.13.3", + "jpeg-decoder", + "num-traits", + "png 0.17.16", + "qoi", + "tiff 0.9.1", +] + [[package]] name = "image" version = "0.25.9" @@ -2069,7 +2277,7 @@ dependencies = [ "byteorder-lite", "color_quant", "exr", - "gif", + "gif 0.14.0", "image-webp", "moxcms", "num-traits", @@ -2078,7 +2286,7 @@ dependencies = [ "ravif", "rayon", "rgb", - "tiff", + "tiff 0.10.3", "zune-core 0.5.0", "zune-jpeg 0.5.5", ] @@ -2122,6 +2330,19 @@ dependencies = [ "serde_core", ] +[[package]] +name = "indicatif" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88" +dependencies = [ + "console", + "portable-atomic", + "unicode-width", + "unit-prefix", + "web-time", +] + [[package]] name = "infer" version = "0.19.0" @@ -2247,6 +2468,15 @@ dependencies = [ "libc", ] +[[package]] +name = "jpeg-decoder" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00810f1d8b74be64b13dbf3db89ac67740615d6c891f0e7b6179326533011a07" +dependencies = [ + "rayon", +] + [[package]] name = "js-sys" version = "0.3.82" @@ -2382,6 +2612,7 @@ checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" dependencies = [ "bitflags 2.10.0", "libc", + "redox_syscall", ] [[package]] @@ -2426,6 +2657,22 @@ version = "0.1.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "markup5ever" version = "0.14.1" @@ -2457,6 +2704,16 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -2488,6 +2745,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2509,6 +2772,28 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + [[package]] name = "moxcms" version = "0.7.9" @@ -2557,6 +2842,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk" version = "0.9.0" @@ -2612,6 +2912,16 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nom" version = "8.0.0" @@ -2637,6 +2947,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2975,6 
+3294,28 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags 2.10.0", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "open" version = "5.3.2" @@ -3047,6 +3388,31 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "ort" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec 2.0.0-alpha.10", + "tracing", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "pango" version = "0.18.3" @@ -3097,7 +3463,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.15.1", "windows-link 0.2.1", ] @@ -3119,6 +3485,15 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -3341,6 +3716,21 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -3681,6 +4071,12 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -3691,6 +4087,17 @@ dependencies = [ "rayon-core", ] +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools", + "rayon", +] + [[package]] name = "rayon-core" version = "1.13.0" @@ -4055,7 +4462,7 @@ dependencies = [ "phf_codegen 0.8.0", "precomputed-hash", "servo_arc", - "smallvec", + "smallvec 1.15.1", ] [[package]] @@ -4200,7 +4607,7 
@@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08a72d8216842fdd57820dc78d840bef99248e35fb2554ff923319e60f2d686b" dependencies = [ - "darling", + "darling 0.21.3", "proc-macro2", "quote", "syn 2.0.110", @@ -4303,6 +4710,12 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "socket2" version = "0.6.1" @@ -4313,6 +4726,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "softbuffer" version = "0.4.6" @@ -4361,6 +4785,18 @@ dependencies = [ "system-deps", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom 7.1.3", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -4694,6 +5130,17 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "target-lexicon" version = "0.12.16" @@ -5059,6 +5506,17 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "tiff" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e" +dependencies = [ + "flate2", + "jpeg-decoder", + "weezl", +] + [[package]] name = "tiff" version = "0.10.3" @@ -5114,6 +5572,40 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tokenizers" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b238e22d44a15349529690fb07bd645cf58149a1b1e44d6cb5bd1641ff1a6223" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.4", + "indicatif", + "itertools", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.17", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.48.0" @@ -5445,18 +5937,75 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec 1.15.1", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "unit-prefix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.7" @@ -5675,7 +6224,7 @@ dependencies = [ "downcast-rs", "rustix", "scoped-tls", - "smallvec", + "smallvec 1.15.1", "wayland-sys", ] @@ -5735,6 +6284,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webkit2gtk" version = "2.0.1" @@ -5779,6 +6338,15 @@ dependencies = [ "system-deps", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webview2-com" version = "0.38.0" @@ -6438,6 +7006,16 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + [[package]] name = "y4m" version = "0.8.0" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 8ded6e1..743440e 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -31,6 +31,20 @@ reqwest = { version = "0.12", features = ["stream", "json"] } futures-util = "0.3" directories = "5.0" symphonia = { version = "0.5", features = ["mp3", "wav", "flac"] } +image = "0.24" +ndarray = "0.16.1" +tokenizers = { version = "0.22.1" } +thiserror = "1.0" + +# ONNX Runtime with platform-specific execution providers +[target.'cfg(target_os = "windows")'.dependencies] +ort = { version = "=2.0.0-rc.10", features = ["directml", "download-binaries"] } + +[target.'cfg(target_os = "linux")'.dependencies] +ort = { version = "=2.0.0-rc.10", features = ["cuda", "download-binaries"] } + +[target.'cfg(target_os = "macos")'.dependencies] +ort = { 
version = "=2.0.0-rc.10", features = ["coreml", "download-binaries"] } [dev-dependencies] image = "0.25" diff --git a/src-tauri/src/inference_engine.rs b/src-tauri/src/inference_engine.rs index 8e6a9de..d0fc1ce 100644 --- a/src-tauri/src/inference_engine.rs +++ b/src-tauri/src/inference_engine.rs @@ -255,7 +255,8 @@ extern "C" { fn mtmd_bitmap_init_from_audio(n_samples: usize, data: *const c_float) -> *mut MtmdBitmap; fn mtmd_bitmap_is_audio(bitmap: *const MtmdBitmap) -> bool; fn mtmd_bitmap_free(bitmap: *mut MtmdBitmap); - fn mtmd_support_audio(ctx: *mut MtmdContext) -> bool; + fn mtmd_support_vision(ctx: *const MtmdContext) -> bool; + fn mtmd_support_audio(ctx: *const MtmdContext) -> bool; fn mtmd_get_audio_bitrate(ctx: *mut MtmdContext) -> c_int; fn mtmd_input_chunks_init() -> *mut MtmdInputChunks; diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index c74409e..a3969d5 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -10,14 +10,19 @@ mod llama_inference; pub mod model_manager; pub mod inference_engine; mod audio_decoder; +mod vlm_onnx; use model_manager::{ModelManager, DownloadProgress}; use inference_engine::{InferenceEngine, SharedInferenceEngine, create_shared_engine}; use audio_decoder::decode_audio_file; +use vlm_onnx::VlmOnnx; // Download cancellation state pub type DownloadCancellation = Arc; +// Shared ONNX engine +pub type SharedOnnxEngine = Arc>>; + // Chat message for conversation history #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { @@ -285,12 +290,122 @@ async fn generate_response_audio( response.map_err(|e| e.to_string()) } +// Download ONNX model from HuggingFace +#[tauri::command] +async fn download_onnx_model( + repo: String, + quantization: String, + app: tauri::AppHandle, + cancellation: State<'_, DownloadCancellation>, +) -> Result { + cancellation.store(false, Ordering::SeqCst); + + let manager = ModelManager::new().map_err(|e| e.to_string())?; + let cancel_flag = cancellation.inner().clone(); + + let model_id = manager + .download_smolvlm_onnx( + &repo, + &quantization, + move |progress: DownloadProgress| { + let _ = app.emit("download-progress", &progress); + }, + cancel_flag, + ) + .await + .map_err(|e| e.to_string())?; + + Ok(model_id) +} + +// Load ONNX model +#[tauri::command] +async fn load_onnx_model( + model_id: String, + onnx_engine: State<'_, SharedOnnxEngine>, +) -> Result<(), String> { + println!("Loading ONNX model: {}", model_id); + + let manager = ModelManager::new().map_err(|e| e.to_string())?; + let (vision_path, embed_path, decoder_path, tokenizer_path) = manager + .get_onnx_model_paths(&model_id) + .await + .map_err(|e| e.to_string())?; + + println!("Vision: {:?}", vision_path); + println!("Embed: {:?}", embed_path); + println!("Decoder: {:?}", decoder_path); + println!("Tokenizer: {:?}", tokenizer_path); + + // Load the ONNX model in a blocking task + let engine = tokio::task::spawn_blocking(move || { + VlmOnnx::new(&vision_path, &embed_path, &decoder_path, &tokenizer_path) + }) + .await + .map_err(|e| format!("Task join error: {}", e))? 
+ .map_err(|e| e.to_string())?; + + // Store the engine + let mut engine_lock = onnx_engine.lock().await; + *engine_lock = Some(engine); + + println!("ONNX model loaded successfully"); + Ok(()) +} + +// Generate response with ONNX model +#[tauri::command] +async fn generate_onnx_response( + prompt: String, + image_data: Vec<u8>, + image_width: u32, + image_height: u32, + onnx_engine: State<'_, SharedOnnxEngine>, +) -> Result<String, String> { + println!("Generating ONNX response"); + + let mut engine_lock = onnx_engine.lock().await; + let mut engine_opt = engine_lock.take(); + + if engine_opt.is_none() { + return Err("ONNX model not loaded".to_string()); + } + + drop(engine_lock); + + // Run inference in a blocking task + let (response, engine_instance) = tokio::task::spawn_blocking(move || { + let mut eng = engine_opt.take().unwrap(); + let result = eng.generate(&prompt, &image_data, image_width, image_height); + (result, eng) + }) + .await + .map_err(|e| format!("Task join error: {}", e))?; + + // Put the engine back + let mut engine_lock = onnx_engine.lock().await; + *engine_lock = Some(engine_instance); + + response.map_err(|e| e.to_string()) +} + +// Check if ONNX model is downloaded +#[tauri::command] +async fn is_onnx_model_downloaded(model_id: String) -> Result<bool, String> { + let manager = ModelManager::new().map_err(|e| e.to_string())?; + manager + .is_onnx_model_downloaded(&model_id) + .await + .map_err(|e| e.to_string()) +} + #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { tauri::Builder::default() .plugin(tauri_plugin_opener::init()) .plugin(tauri_plugin_dialog::init()) .manage(create_shared_engine()) + .manage(Arc::new(tokio::sync::Mutex::new(None::<VlmOnnx>))) // ONNX engine .manage(Arc::new(AtomicBool::new(false))) // Download cancellation flag .invoke_handler(tauri::generate_handler![ greet, @@ -304,7 +419,11 @@ load_model, generate_response, check_audio_support, - generate_response_audio + generate_response_audio, + download_onnx_model, + load_onnx_model, + generate_onnx_response, + is_onnx_model_downloaded ]) .setup(|app| { // Create menu diff --git a/src-tauri/src/llama_inference.rs b/src-tauri/src/llama_inference.rs index c7e9839..793fdf3 100644 --- a/src-tauri/src/llama_inference.rs +++ b/src-tauri/src/llama_inference.rs @@ -110,6 +110,9 @@ extern "C" { ) -> *mut MtmdContext; fn mtmd_free(ctx: *mut MtmdContext); + fn mtmd_support_vision(ctx: *const MtmdContext) -> bool; + fn mtmd_support_audio(ctx: *const MtmdContext) -> bool; + fn mtmd_bitmap_init(nx: u32, ny: u32, data: *const u8) -> *mut MtmdBitmap; fn mtmd_bitmap_free(bitmap: *mut MtmdBitmap); @@ -192,6 +195,44 @@ impl LlamaInference { // with proper batching, sampling, and token generation Ok(format!("Mock response to: {}", prompt)) } + + /// Check if a GGUF file supports vision or audio by loading it with mtmd + /// This is the proper way to detect multimodal capabilities, as used by llama-mtmd-cli + pub fn check_multimodal_support(model_path: &str, mmproj_path: &str) -> Result<(bool, bool)> { + unsafe { + // Load the main model temporarily + let model_path_c = CString::new(model_path)?; + let mut model_params = llama_model_default_params(); + model_params.n_gpu_layers = 0; // CPU only for quick check + model_params.use_mmap = true; + model_params.use_mlock = false; + + let model = llama_model_load_from_file(model_path_c.as_ptr(), model_params); + if model.is_null() { + return Err(anyhow!("Failed to load model from {}", model_path)); + } + + // Try to load the mmproj file + let mmproj_c = CString::new(mmproj_path)?; + let 
mtmd_params = mtmd_context_params_default(); + let mtmd = mtmd_init_from_file(mmproj_c.as_ptr(), model, mtmd_params); + + if mtmd.is_null() { + llama_model_free(model); + return Err(anyhow!("Failed to load mmproj from {}", mmproj_path)); + } + + // Check what the model supports + let supports_vision = mtmd_support_vision(mtmd); + let supports_audio = mtmd_support_audio(mtmd); + + // Clean up + mtmd_free(mtmd); + llama_model_free(model); + + Ok((supports_vision, supports_audio)) + } + } } impl Drop for LlamaInference { @@ -211,3 +252,192 @@ impl Drop for LlamaInference { } unsafe impl Send for LlamaInference {} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + // Helper to get test model paths + fn get_test_model_dir() -> PathBuf { + // Use the models directory from ModelManager + crate::model_manager::ModelManager::get_models_directory() + .unwrap_or_else(|_| PathBuf::from("test_models")) + } + + #[test] + #[ignore] // Requires actual model files + fn test_check_multimodal_support_vision_model() { + // Test with SmolVLM2 (separate mmproj file) + let model_dir = get_test_model_dir().join("smolvlm2-2.2b-instruct"); + + if !model_dir.exists() { + eprintln!("Skipping test - model directory not found: {:?}", model_dir); + return; + } + + let lang_file = model_dir.join("SmolVLM2-2.2B-Instruct-Q4_K_M.gguf"); + let mmproj_file = model_dir.join("mmproj-SmolVLM2-2.2B-Instruct-f16.gguf"); + + if !lang_file.exists() || !mmproj_file.exists() { + eprintln!("Skipping test - model files not found"); + return; + } + + let result = LlamaInference::check_multimodal_support( + lang_file.to_str().unwrap(), + mmproj_file.to_str().unwrap() + ); + + assert!(result.is_ok(), "Should successfully check model"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!(supports_vision, "SmolVLM2 should support vision"); + assert!(!supports_audio, "SmolVLM2 should not support audio"); + } + + #[test] + #[ignore] // Requires actual model files + fn test_check_multimodal_support_unified_model() { + // Test with Ministral or other unified model + let model_dir = get_test_model_dir(); + + // Look for any unified model (vision encoder embedded in main file) + // This would be a model like Ministral-8B-Instruct-2410-Q4_K_M.gguf + let mut unified_model_path: Option = None; + + if let Ok(entries) = std::fs::read_dir(&model_dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + if let Ok(files) = std::fs::read_dir(&path) { + for file in files.flatten() { + let file_path = file.path(); + let file_name = file_path.file_name() + .and_then(|n| n.to_str()) + .unwrap_or(""); + + // Look for a single GGUF file (unified model) + if file_name.ends_with(".gguf") + && !file_name.contains("mmproj") + && !file_name.contains("vision") { + unified_model_path = Some(file_path); + break; + } + } + } + if unified_model_path.is_some() { + break; + } + } + } + } + + if let Some(model_path) = unified_model_path { + println!("Testing unified model: {:?}", model_path); + + // For unified models, mmproj_path is the same as model_path + let result = LlamaInference::check_multimodal_support( + model_path.to_str().unwrap(), + model_path.to_str().unwrap() + ); + + assert!(result.is_ok(), "Should successfully check unified model"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!( + supports_vision || supports_audio, + "Unified model should support at least vision or audio" + ); + } else { + eprintln!("Skipping test - no unified model found"); + } + } + + #[test] 
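+    // Like the other #[ignore] tests here, this needs a locally downloaded audio model
+    // (the test scans for Ultravox or Qwen2-Audio directories); run with `cargo test -- --ignored`.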
+ #[ignore] // Requires actual model files + fn test_check_multimodal_support_audio_model() { + // Test with Ultravox or other audio model + let model_dir = get_test_model_dir(); + + // Look for audio model directory + let audio_model_dirs = ["ultravox", "qwen2_audio"]; + let mut audio_model_path: Option = None; + let mut audio_mmproj_path: Option = None; + + for dir_name in audio_model_dirs { + let test_dir = model_dir.join(dir_name); + if test_dir.exists() { + if let Ok(entries) = std::fs::read_dir(&test_dir) { + for entry in entries.flatten() { + let path = entry.path(); + let file_name = path.file_name() + .and_then(|n| n.to_str()) + .unwrap_or(""); + + if file_name.ends_with(".gguf") { + if file_name.contains("mmproj") || file_name.contains("audio") { + audio_mmproj_path = Some(path); + } else { + audio_model_path = Some(path); + } + } + } + } + if audio_model_path.is_some() { + break; + } + } + } + + if let (Some(model_path), Some(mmproj_path)) = (audio_model_path.clone(), audio_mmproj_path.clone()) { + println!("Testing audio model: {:?}", model_path); + + let result = LlamaInference::check_multimodal_support( + model_path.to_str().unwrap(), + mmproj_path.to_str().unwrap() + ); + + assert!(result.is_ok(), "Should successfully check audio model"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!(supports_audio, "Audio model should support audio"); + } else if let Some(model_path) = audio_model_path { + // Try as unified model + println!("Testing audio model (unified): {:?}", model_path); + + let result = LlamaInference::check_multimodal_support( + model_path.to_str().unwrap(), + model_path.to_str().unwrap() + ); + + assert!(result.is_ok(), "Should successfully check unified audio model"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!(supports_audio, "Audio model should support audio"); + } else { + eprintln!("Skipping test - no audio model found"); + } + } + + #[test] + fn test_check_multimodal_support_nonexistent_files() { + let result = LlamaInference::check_multimodal_support( + "/nonexistent/model.gguf", + "/nonexistent/mmproj.gguf" + ); + + assert!(result.is_err(), "Should fail with nonexistent files"); + } + + #[test] + #[ignore] // Requires actual model files + fn test_check_multimodal_support_text_only_model() { + // This test would require a text-only GGUF model + // which should fail multimodal detection + let model_dir = get_test_model_dir(); + + // Look for a text-only model (no vision/audio support) + // These typically won't be in the models directory for this app + // but this test documents expected behavior + + eprintln!("Note: This test requires a text-only GGUF model"); + eprintln!("Text-only models should fail with 'Failed to load mmproj' error"); + } +} diff --git a/src-tauri/src/model_manager.rs b/src-tauri/src/model_manager.rs index 8d9ac57..ebff617 100644 --- a/src-tauri/src/model_manager.rs +++ b/src-tauri/src/model_manager.rs @@ -39,7 +39,7 @@ pub fn get_bundled_model_info() -> ModelInfo { } pub struct ModelManager { - models_dir: PathBuf, + pub models_dir: PathBuf, } impl ModelManager { @@ -98,9 +98,17 @@ impl ModelManager { if let Ok(file_type) = entry.file_type().await { if file_type.is_dir() { if let Some(model_id) = entry.file_name().to_str() { - // Check if this directory contains at least one GGUF file - if self.is_model_downloaded(model_id).await? 
{ - models.push(model_id.to_string()); + // Check if this is an ONNX model (has _onnx_ in the name) + if model_id.contains("_onnx_") { + // For ONNX models, check if directory contains .onnx files + if self.is_onnx_model_downloaded(model_id).await? { + models.push(model_id.to_string()); + } + } else { + // For GGUF models, check if directory contains .gguf files + if self.is_model_downloaded(model_id).await? { + models.push(model_id.to_string()); + } } } } @@ -119,16 +127,23 @@ // Find GGUF files in directory let mut language_file: Option<PathBuf> = None; - let mut mmproj_file: Option<PathBuf> = None; + let mut mmproj_file_f16: Option<PathBuf> = None; + let mut mmproj_file_other: Option<PathBuf> = None; let mut entries = tokio::fs::read_dir(&model_dir).await?; while let Some(entry) = entries.next_entry().await? { let path = entry.path(); if let Some(filename) = path.file_name() { let filename_str = filename.to_string_lossy(); + let filename_upper = filename_str.to_uppercase(); if filename_str.ends_with(".gguf") { if filename_str.contains("mmproj") || filename_str.contains("vision") { - mmproj_file = Some(path.clone()); + // Prefer F16 mmproj files + if filename_upper.contains("F16") || filename_upper.contains("_F16") { + mmproj_file_f16 = Some(path.clone()); + } else if mmproj_file_other.is_none() { + mmproj_file_other = Some(path.clone()); + } } else if language_file.is_none() { language_file = Some(path.clone()); } @@ -136,6 +151,9 @@ } } + // Prefer F16 mmproj if available, otherwise use other quantization + let mmproj_file = mmproj_file_f16.or(mmproj_file_other); + match (language_file, mmproj_file) { (Some(lang), Some(mmproj)) => Ok((lang, mmproj)), (Some(lang), None) => { @@ -191,70 +209,57 @@ Ok(()) } - /// Verify that a model is multimodal by checking HuggingFace API metadata - async fn verify_multimodal_model(&self, repo: &str) -> Result<()> { - let api_url = format!("https://huggingface.co/api/models/{}", repo); - let client = reqwest::Client::new(); + /// Verify that downloaded model files are actually multimodal + /// This is done by loading them with mtmd and checking the GGUF metadata + pub async fn verify_downloaded_model(&self, model_id: &str) -> Result<(bool, bool)> { + use crate::llama_inference::LlamaInference; - #[derive(Deserialize)] - struct ModelInfo { - pipeline_tag: Option<String>, - tags: Option<Vec<String>>, - } + let model_dir = self.models_dir.join(model_id); - let response = client.get(&api_url).send().await?; + // Find the main model file and potential mmproj file + let mut language_file: Option<PathBuf> = None; + let mut mmproj_file_f16: Option<PathBuf> = None; + let mut mmproj_file_other: Option<PathBuf> = None; - if !response.status().is_success() { - return Err(anyhow!("Failed to fetch model info from HuggingFace API")); + let mut entries = tokio::fs::read_dir(&model_dir).await?; + while let Some(entry) = entries.next_entry().await? 
{ + let path = entry.path(); + if let Some(filename) = path.file_name() { + let filename_str = filename.to_string_lossy(); + let filename_upper = filename_str.to_uppercase(); + if filename_str.ends_with(".gguf") { + if filename_str.contains("mmproj") || filename_str.contains("vision") { + // Prefer F16 mmproj files + if filename_upper.contains("F16") || filename_upper.contains("_F16") { + mmproj_file_f16 = Some(path.clone()); + } else if mmproj_file_other.is_none() { + mmproj_file_other = Some(path.clone()); + } + } else if language_file.is_none() { + language_file = Some(path.clone()); + } + } + } } - let model_info: ModelInfo = response.json().await?; - - // Check if it's a multimodal model - let is_multimodal = match model_info.pipeline_tag.as_deref() { - Some("image-text-to-text") => true, - Some("audio-text-to-text") => true, - Some("visual-question-answering") => true, - Some("video-text-to-text") => true, - _ => { - // Also check tags (case-insensitive) - let has_multimodal_tag = model_info.tags.as_ref().map_or(false, |tags| { - tags.iter().any(|tag| { - let tag_lower = tag.to_lowercase(); - tag_lower.contains("vision") || - tag_lower.contains("multimodal") || - tag_lower.contains("vlm") || - tag_lower.contains("audio") || - tag_lower.contains("video") || - tag_lower.contains("image-text") - }) - }); + let lang_path = language_file.ok_or_else(|| anyhow!("No GGUF model file found"))?; - // Also check repo name for known indicators - let repo_lower = repo.to_lowercase(); - let has_multimodal_name = repo_lower.contains("vlm") || - repo_lower.contains("vision") || - repo_lower.contains("video") || - repo_lower.contains("audio") || - repo_lower.contains("multimodal") || - repo_lower.contains("smolvlm") || - repo_lower.contains("pixtral") || - repo_lower.contains("ultravox") || - repo_lower.contains("moondream"); - - has_multimodal_tag || has_multimodal_name - } - }; + // Prefer F16 mmproj if available, otherwise use other quantization + let mmproj_file = mmproj_file_f16.or(mmproj_file_other); - if !is_multimodal { - return Err(anyhow!( - "Model '{}' does not appear to be a multimodal (vision/audio) model. \ - Baseweight Canvas requires vision-language or audio models with mmproj support.", - repo - )); - } + // Try mmproj file first, fall back to language file for unified models + let mmproj_path = mmproj_file.as_ref().unwrap_or(&lang_path); - Ok(()) + // Check multimodal support by actually loading the model + // Run in blocking task to avoid blocking the async runtime + let lang_path_str = lang_path.to_str().unwrap().to_string(); + let mmproj_path_str = mmproj_path.to_str().unwrap().to_string(); + + tokio::task::spawn_blocking(move || { + LlamaInference::check_multimodal_support(&lang_path_str, &mmproj_path_str) + }) + .await + .map_err(|e| anyhow!("Verification task failed: {}", e))? 
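+        // spawn_blocking's JoinError is mapped above; the inner Result<(bool, bool)> from
+        // check_multimodal_support is the tail expression and becomes this function's return value.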
} async fn download_file( @@ -325,9 +330,6 @@ impl ModelManager { /// Validate and get model files from HuggingFace repository pub async fn validate_huggingface_repo(&self, repo: &str, quantization: &str) -> Result<(String, String, u64, u64)> { - // First, check if the model is actually multimodal via HF API - self.verify_multimodal_model(repo).await?; - // Try manifest API first (supports unified models) let manifest_url = format!("https://huggingface.co/v2/{}/manifests/latest", repo); let client = reqwest::Client::new(); @@ -350,8 +352,10 @@ if let Ok(manifest) = resp.json::().await { if let Some(language_file) = manifest.gguf_file { + // Strip any suffix like "(F16 mmproj)" from quantization string + let clean_quant = quantization.split('(').next().unwrap_or(quantization).trim(); // Check if this file matches the desired quantization - if language_file.to_uppercase().contains(&quantization.to_uppercase()) { + if language_file.to_uppercase().contains(&clean_quant.to_uppercase()) { let vision_file = manifest.mmproj_file.unwrap_or_else(|| language_file.clone()); // Get file sizes @@ -400,8 +404,12 @@ let files: Vec = response.json().await?; let mut language_file: Option<(String, u64)> = None; - let mut vision_file: Option<(String, u64)> = None; - let quant_upper = quantization.to_uppercase(); + let mut vision_file_f16: Option<(String, u64)> = None; + let mut vision_file_other: Option<(String, u64)> = None; + + // Strip any suffix like "(F16 mmproj)" from quantization string + let clean_quant = quantization.split('(').next().unwrap_or(quantization).trim(); + let quant_upper = clean_quant.to_uppercase(); for file in files { let path = file.path.to_lowercase(); @@ -410,7 +418,12 @@ if path.ends_with(".gguf") { if path.contains("mmproj") || path.contains("vision") { - vision_file = Some((file.path, size)); + // Prefer F16 mmproj files, but keep track of other quantizations + if path_upper.contains("F16") || path_upper.contains("_F16") { + vision_file_f16 = Some((file.path, size)); + } else if vision_file_other.is_none() { + vision_file_other = Some((file.path, size)); + } } else if language_file.is_none() && path_upper.contains(&quant_upper) { // Only select language files matching the desired quantization language_file = Some((file.path, size)); @@ -418,6 +431,9 @@ } } + // Prefer F16 mmproj if available, otherwise use other quantization + let vision_file = vision_file_f16.or(vision_file_other); + match (language_file, vision_file) { (Some((lang, lang_size)), Some((vis, vis_size))) => { Ok((lang, vis, lang_size, vis_size)) @@ -469,8 +485,138 @@ } } + println!("Model downloaded successfully"); + Ok(()) } + + /// Download SmolVLM ONNX model from HuggingFace + pub async fn download_smolvlm_onnx<F>( + &self, + repo: &str, + quantization: &str, + progress_callback: F, + cancel_flag: Arc<AtomicBool>, + ) -> Result<String> + where + F: Fn(DownloadProgress) + Send + 'static, + { + self.ensure_models_directory().await?; + + let model_id = format!("{}_onnx_{}", + repo.replace("/", "_").replace("-", "_").to_lowercase(), + quantization.to_lowercase() + ); + let model_dir = self.models_dir.join(&model_id); + fs::create_dir_all(&model_dir).await?; + + // ONNX files to download for SmolVLM + let quant_suffix = match quantization { + "Q4" | "q4" => "q4", + "Q8" | "q8" => "q8", + "FP16" | "fp16" => "fp16", + _ => "q4", // default to Q4 + }; + + // ONNX files are in onnx/ subdirectory, config/tokenizer at root + let onnx_files = vec![ 
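+            // These three graphs, plus tokenizer.json below, are exactly the files that
+            // get_onnx_model_paths and VlmOnnx::new expect to find later.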
+ format!("vision_encoder_{}.onnx", quant_suffix), + format!("embed_tokens_{}.onnx", quant_suffix), + format!("decoder_model_merged_{}.onnx", quant_suffix), + ]; + + let root_files = vec![ + "config.json".to_string(), + "tokenizer.json".to_string(), + ]; + + // Download ONNX files from onnx/ subdirectory + for filename in onnx_files { + let file_path = model_dir.join(&filename); + + // Skip if already downloaded + if file_path.exists() { + progress_callback(DownloadProgress { + current: 100, + total: 100, + percentage: 100.0, + file: format!("{} (already downloaded)", filename), + }); + continue; + } + + let url = format!("https://huggingface.co/{}/resolve/main/onnx/{}", repo, filename); + self.download_file(&url, &file_path, &filename, &progress_callback, cancel_flag.clone()).await?; + } + + // Download config and tokenizer from root + for filename in root_files { + let file_path = model_dir.join(&filename); + + // Skip if already downloaded + if file_path.exists() { + progress_callback(DownloadProgress { + current: 100, + total: 100, + percentage: 100.0, + file: format!("{} (already downloaded)", filename), + }); + continue; + } + + let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename); + self.download_file(&url, &file_path, &filename, &progress_callback, cancel_flag.clone()).await?; + } + + Ok(model_id) + } + + /// Get ONNX model paths + pub async fn get_onnx_model_paths(&self, model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf, PathBuf)> { + let model_dir = self.models_dir.join(model_id); + + if !model_dir.exists() { + return Err(anyhow!("ONNX model directory not found: {}", model_id)); + } + + // Find ONNX files + let mut vision_file: Option = None; + let mut embed_file: Option = None; + let mut decoder_file: Option = None; + let mut tokenizer_file: Option = None; + + let mut entries = tokio::fs::read_dir(&model_dir).await?; + while let Some(entry) = entries.next_entry().await? 
{ + let path = entry.path(); + if let Some(filename) = path.file_name() { + let filename_str = filename.to_string_lossy(); + if filename_str.contains("vision_encoder") && filename_str.ends_with(".onnx") { + vision_file = Some(path.clone()); + } else if filename_str.contains("embed_tokens") && filename_str.ends_with(".onnx") { + embed_file = Some(path.clone()); + } else if filename_str.contains("decoder_model") && filename_str.ends_with(".onnx") { + decoder_file = Some(path.clone()); + } else if filename_str == "tokenizer.json" { + tokenizer_file = Some(path.clone()); + } + } + } + + match (vision_file, embed_file, decoder_file, tokenizer_file) { + (Some(vision), Some(embed), Some(decoder), Some(tokenizer)) => { + Ok((vision, embed, decoder, tokenizer)) + } + _ => Err(anyhow!("Missing ONNX model files in directory")), + } + } + + /// Check if ONNX model is downloaded + pub async fn is_onnx_model_downloaded(&self, model_id: &str) -> Result<bool> { + match self.get_onnx_model_paths(model_id).await { + Ok(_) => Ok(true), + Err(_) => Ok(false), + } + } } #[cfg(test)] @@ -522,44 +668,13 @@ mod tests { // Test with real HuggingFace API - marked as #[ignore] so it doesn't run by default // Run with: cargo test -- --ignored --test-threads=1 - #[tokio::test] - #[ignore] - async fn test_verify_multimodal_model_valid_vision() { - let manager = ModelManager::new().unwrap(); - - // Test with SmolVLM2 (known vision model) - let result = manager.verify_multimodal_model("ggml-org/SmolVLM2-2.2B-Instruct-GGUF").await; - assert!(result.is_ok(), "SmolVLM2 should be recognized as multimodal"); - } - - #[tokio::test] - #[ignore] - async fn test_verify_multimodal_model_valid_audio() { - let manager = ModelManager::new().unwrap(); - - // Test with Ultravox (known audio model) - let result = manager.verify_multimodal_model("ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF").await; - assert!(result.is_ok(), "Ultravox should be recognized as multimodal"); - } - - #[tokio::test] - #[ignore] - async fn test_verify_multimodal_model_invalid_text_only() { - let manager = ModelManager::new().unwrap(); - - // Test with a text-only model (should fail) - let result = manager.verify_multimodal_model("meta-llama/Llama-2-7b-hf").await; - assert!(result.is_err(), "Text-only Llama model should be rejected"); - assert!(result.unwrap_err().to_string().contains("multimodal")); - } - #[tokio::test] #[ignore] async fn test_validate_huggingface_repo_valid() { let manager = ModelManager::new().unwrap(); // Test with SmolVLM2 - let result = manager.validate_huggingface_repo("ggml-org/SmolVLM2-2.2B-Instruct-GGUF").await; + let result = manager.validate_huggingface_repo("ggml-org/SmolVLM2-2.2B-Instruct-GGUF", "Q4_K_M").await; assert!(result.is_ok()); let (lang_file, vision_file, lang_size, vision_size) = result.unwrap(); @@ -574,7 +689,7 @@ let manager = ModelManager::new().unwrap(); // Test with non-existent repo - let result = manager.validate_huggingface_repo("nonexistent/fake-model-12345").await; + let result = manager.validate_huggingface_repo("nonexistent/fake-model-12345", "Q4_K_M").await; assert!(result.is_err()); } @@ -584,7 +699,7 @@ let manager = ModelManager::new().unwrap(); // Test scanning a real repo with separate mmproj file - let result = manager.scan_repo_for_gguf("ggml-org/SmolVLM2-2.2B-Instruct-GGUF").await; + let result = manager.scan_repo_for_gguf("ggml-org/SmolVLM2-2.2B-Instruct-GGUF", "Q4_K_M").await; assert!(result.is_ok()); let (lang_file, vision_file, lang_size, vision_size) = result.unwrap(); @@ -634,4 
+749,167 @@ mod tests { assert_eq!(progress.total, 1000); assert_eq!(progress.percentage, 10.0); } + + // Tests for verify_downloaded_model + #[tokio::test] + #[ignore] // Requires actual downloaded model + async fn test_verify_downloaded_model_vision() { + let manager = ModelManager::new().unwrap(); + + // Test with SmolVLM2 if it exists + let model_id = "smolvlm2-2.2b-instruct"; + let model_dir = manager.models_dir.join(model_id); + + if !model_dir.exists() { + eprintln!("Skipping test - model not downloaded: {}", model_id); + return; + } + + let result = manager.verify_downloaded_model(model_id).await; + + assert!(result.is_ok(), "Should successfully verify SmolVLM2"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!(supports_vision, "SmolVLM2 should support vision"); + assert!(!supports_audio, "SmolVLM2 should not support audio"); + } + + #[tokio::test] + #[ignore] // Requires actual downloaded model + async fn test_verify_downloaded_model_unified() { + let manager = ModelManager::new().unwrap(); + + // Look for any unified model (single GGUF file with embedded vision) + let mut unified_model_id: Option = None; + + if let Ok(entries) = tokio::fs::read_dir(&manager.models_dir).await { + let mut entries = entries; + while let Ok(Some(entry)) = entries.next_entry().await { + if let Ok(file_type) = entry.file_type().await { + if file_type.is_dir() { + if let Some(model_id) = entry.file_name().to_str() { + let model_path = manager.models_dir.join(model_id); + + // Check if this directory has a single GGUF file (unified model) + let mut has_main_gguf = false; + let mut has_mmproj = false; + + if let Ok(files) = tokio::fs::read_dir(&model_path).await { + let mut files = files; + while let Ok(Some(file)) = files.next_entry().await { + if let Some(filename) = file.file_name().to_str() { + if filename.ends_with(".gguf") { + if filename.contains("mmproj") || filename.contains("vision") { + has_mmproj = true; + } else { + has_main_gguf = true; + } + } + } + } + } + + // Unified model has main GGUF but no separate mmproj + if has_main_gguf && !has_mmproj { + unified_model_id = Some(model_id.to_string()); + break; + } + } + } + } + } + } + + if let Some(model_id) = unified_model_id { + println!("Testing unified model: {}", model_id); + + let result = manager.verify_downloaded_model(&model_id).await; + + assert!(result.is_ok(), "Should successfully verify unified model"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!( + supports_vision || supports_audio, + "Unified model should support at least vision or audio" + ); + } else { + eprintln!("Skipping test - no unified model found"); + } + } + + #[tokio::test] + async fn test_verify_downloaded_model_nonexistent() { + let manager = ModelManager::new().unwrap(); + + let result = manager.verify_downloaded_model("nonexistent-model-xyz-123").await; + + assert!(result.is_err(), "Should fail for nonexistent model"); + } + + #[tokio::test] + #[ignore] // Requires actual downloaded model + async fn test_verify_downloaded_model_audio() { + let manager = ModelManager::new().unwrap(); + + // Look for audio models + let audio_model_patterns = ["ultravox", "qwen2_audio"]; + let mut audio_model_id: Option = None; + + if let Ok(entries) = tokio::fs::read_dir(&manager.models_dir).await { + let mut entries = entries; + while let Ok(Some(entry)) = entries.next_entry().await { + if let Ok(file_type) = entry.file_type().await { + if file_type.is_dir() { + if let Some(model_id) = entry.file_name().to_str() { + let model_id_lower = 
model_id.to_lowercase(); + if audio_model_patterns.iter().any(|p| model_id_lower.contains(p)) { + audio_model_id = Some(model_id.to_string()); + break; + } + } + } + } + } + + if let Some(model_id) = audio_model_id { + println!("Testing audio model: {}", model_id); + + let result = manager.verify_downloaded_model(&model_id).await; + + assert!(result.is_ok(), "Should successfully verify audio model"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!(supports_audio, "Audio model should support audio"); + } else { + eprintln!("Skipping test - no audio model found"); + } + } + + // Integration test for get_model_paths with multimodal detection + #[tokio::test] + #[ignore] // Requires actual downloaded model + async fn test_get_model_paths_with_verification() { + let manager = ModelManager::new().unwrap(); + + // Test with SmolVLM2 if it exists + let model_id = "smolvlm2-2.2b-instruct"; + let model_dir = manager.models_dir.join(model_id); + + if !model_dir.exists() { + eprintln!("Skipping test - model not downloaded: {}", model_id); + return; + } + + // First verify the model is multimodal + let verify_result = manager.verify_downloaded_model(model_id).await; + assert!(verify_result.is_ok(), "Model should pass verification"); + let (supports_vision, _) = verify_result.unwrap(); + assert!(supports_vision, "SmolVLM2 should support vision"); + + // Then get model paths + let paths_result = manager.get_model_paths(model_id).await; + assert!(paths_result.is_ok(), "Should get model paths"); + + let (lang_path, mmproj_path) = paths_result.unwrap(); + assert!(lang_path.exists(), "Language model file should exist"); + assert!(mmproj_path.exists(), "MMProj file should exist"); + } } diff --git a/src-tauri/src/vlm_onnx.rs b/src-tauri/src/vlm_onnx.rs new file mode 100644 index 0000000..2fbaab2 --- /dev/null +++ b/src-tauri/src/vlm_onnx.rs @@ -0,0 +1,331 @@ +// VLM ONNX inference engine for Baseweight Canvas +// Supports ONNX-based Vision-Language Models +// Initially focused on SmolVLM, expandable to other VLMs + +use anyhow::{Result, Context}; +use image::DynamicImage; +use ndarray::{Array, Array4, Array5, s}; +use ort::{ + session::{Session, builder::GraphOptimizationLevel}, + value::Value, +}; +use std::collections::HashMap; +use std::path::Path; +use tokenizers::Tokenizer; + +#[cfg(target_os = "windows")] +use ort::execution_providers::DirectMLExecutionProvider; + +#[cfg(target_os = "linux")] +use ort::execution_providers::CUDAExecutionProvider; + +#[cfg(target_os = "macos")] +use ort::execution_providers::CoreMLExecutionProvider; + +// SmolVLM configuration +struct SmolVLMConfig { + num_key_value_heads: usize, + head_dim: usize, + num_hidden_layers: usize, + eos_token_id: u32, + image_token_id: u32, + max_context_length: usize, +} + +pub struct VlmOnnx { + vision_session: Session, + embed_session: Session, + decoder_session: Session, + tokenizer: Tokenizer, + config: SmolVLMConfig, +} + +impl VlmOnnx { + pub fn new( + vision_model_path: &Path, + embed_model_path: &Path, + decoder_model_path: &Path, + tokenizer_path: &Path, + ) -> Result<Self> { + println!("Loading SmolVLM ONNX models..."); + + // Create sessions with platform-specific execution providers + #[cfg(target_os = "windows")] + let vision_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_execution_providers([DirectMLExecutionProvider::default().build().error_on_failure()])? 
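+        // error_on_failure() turns a failed DirectML registration into a hard error rather than
+        // ort's default behavior of logging and falling back to the CPU provider.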
+ .commit_from_file(vision_model_path) + .context("Failed to create vision session")?; + + #[cfg(target_os = "linux")] + let vision_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()])? + .commit_from_file(vision_model_path) + .context("Failed to create vision session")?; + + #[cfg(target_os = "macos")] + let vision_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + // Using CPU only for now - CoreML has compatibility issues + .commit_from_file(vision_model_path) + .context("Failed to create vision session")?; + + #[cfg(target_os = "windows")] + let embed_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_execution_providers([DirectMLExecutionProvider::default().build().error_on_failure()])? + .commit_from_file(embed_model_path) + .context("Failed to create embed session")?; + + #[cfg(target_os = "linux")] + let embed_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()])? + .commit_from_file(embed_model_path) + .context("Failed to create embed session")?; + + #[cfg(target_os = "macos")] + let embed_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .commit_from_file(embed_model_path) + .context("Failed to create embed session")?; + + #[cfg(target_os = "windows")] + let decoder_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_execution_providers([DirectMLExecutionProvider::default().build().error_on_failure()])? + .commit_from_file(decoder_model_path) + .context("Failed to create decoder session")?; + + #[cfg(target_os = "linux")] + let decoder_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()])? + .commit_from_file(decoder_model_path) + .context("Failed to create decoder session")?; + + #[cfg(target_os = "macos")] + let decoder_session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? 
+        .commit_from_file(decoder_model_path)
+        .context("Failed to create decoder session")?;
+
+        let tokenizer = Tokenizer::from_file(tokenizer_path)
+            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
+
+        // Get special tokens
+        let image_token_id = tokenizer.token_to_id("<image>")
+            .ok_or_else(|| anyhow::anyhow!("Failed to get image token ID"))?;
+        let eos_token_id = 2; // SmolVLM2 EOS token
+
+        // Auto-detect configuration from decoder
+        let decoder_inputs = &decoder_session.inputs;
+        let mut max_layer = 0;
+        let detected_kv_heads = 5;
+
+        for input in decoder_inputs {
+            if input.name.starts_with("past_key_values.") {
+                if let Some(layer_str) = input.name.split('.').nth(1) {
+                    if let Ok(layer_num) = layer_str.parse::<usize>() {
+                        max_layer = max_layer.max(layer_num);
+                    }
+                }
+            }
+        }
+
+        let num_hidden_layers = max_layer + 1;
+        println!("Auto-detected: {} layers, {} kv heads", num_hidden_layers, detected_kv_heads);
+
+        let config = SmolVLMConfig {
+            num_key_value_heads: detected_kv_heads,
+            head_dim: 64,
+            num_hidden_layers,
+            eos_token_id,
+            image_token_id,
+            max_context_length: 2048,
+        };
+
+        Ok(Self {
+            vision_session,
+            embed_session,
+            decoder_session,
+            tokenizer,
+            config,
+        })
+    }
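Note: new() infers the decoder layer count by scanning the ONNX session's input names instead of reading a config JSON, while the KV-head count and head_dim stay hard-coded for SmolVLM2. A standalone sketch of that name-based inference (helper name and signature are illustrative, not part of this patch):

    fn count_decoder_layers(input_names: &[&str]) -> usize {
        // "past_key_values.<layer>.key" / ".value" -> highest layer index + 1
        input_names
            .iter()
            .filter_map(|name| name.strip_prefix("past_key_values."))
            .filter_map(|rest| rest.split('.').next())
            .filter_map(|idx| idx.parse::<usize>().ok())
            .max()
            .map_or(0, |max_layer| max_layer + 1)
    }

    // count_decoder_layers(&["past_key_values.0.key", "past_key_values.23.value"]) == 24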
+
+    pub fn generate(&mut self, prompt: &str, image_data: &[u8], width: u32, height: u32) -> Result<String> {
+        // Load and preprocess image
+        let image = image::load_from_memory(image_data)?;
+        let (processed_image, pixel_attention_mask) = preprocess_image(image)?;
+
+        // Expand prompt with image tokens
+        let expanded_prompt = expand_prompt_with_image(prompt);
+
+        // Tokenize
+        let encoding = self.tokenizer.encode(expanded_prompt, true)
+            .map_err(|e| anyhow::anyhow!("Tokenization error: {}", e))?;
+        let input_ids = encoding.get_ids().iter().map(|&x| x as i64).collect::<Vec<i64>>();
+        let attention_mask = encoding.get_attention_mask().iter().map(|&x| x as i64).collect::<Vec<i64>>();
+
+        // Initialize past key values
+        let batch_size = 1;
+        let mut past_key_values: HashMap<String, ndarray::ArrayD<f32>> = HashMap::new();
+        for layer in 0..self.config.num_hidden_layers {
+            let key_array = Array::zeros((batch_size, self.config.num_key_value_heads, 0, self.config.head_dim)).into_dyn();
+            let value_array = Array::zeros((batch_size, self.config.num_key_value_heads, 0, self.config.head_dim)).into_dyn();
+            past_key_values.insert(format!("past_key_values.{}.key", layer), key_array);
+            past_key_values.insert(format!("past_key_values.{}.value", layer), value_array);
+        }
+
+        // Get image features
+        let mut vision_inputs: HashMap<&str, Value> = HashMap::new();
+        vision_inputs.insert("pixel_values", Value::from_array(processed_image.clone())?.into());
+
+        let pixel_attention_mask_bool = pixel_attention_mask.map(|&x| x != 0);
+        vision_inputs.insert("pixel_attention_mask", Value::from_array(pixel_attention_mask_bool)?.into());
+
+        let vision_outputs = self.vision_session.run(vision_inputs)?;
+        let image_features = vision_outputs[0].try_extract_array::<f32>()?.to_owned();
+
+        // Reshape features
+        let total_size = image_features.shape()[0] * image_features.shape()[1];
+        let feature_dim = image_features.shape()[2];
+        let image_features_reshaped = image_features.into_shape_with_order((total_size, feature_dim))?;
+
+        // Generation loop
+        let max_new_tokens = 1024;
+        let mut generated_tokens = Vec::new();
+        let mut input_ids = Array::from_vec(input_ids)
+            .into_shape_with_order((1, encoding.get_ids().len()))?
+            .into_owned();
+        let mut attention_mask = Array::from_vec(attention_mask)
+            .into_shape_with_order((1, encoding.get_attention_mask().len()))?
+            .into_owned();
+
+        // Calculate position_ids
+        let mut position_ids_vec = Vec::new();
+        let mut cumsum = 0i64;
+        for &mask_val in attention_mask.iter() {
+            cumsum += mask_val;
+            position_ids_vec.push(cumsum);
+        }
+        let mut position_ids = Array::from_vec(position_ids_vec)
+            .into_shape_with_order((1, attention_mask.len()))?
+            .into_owned();
+
+        for _ in 0..max_new_tokens {
+            // Get input embeddings
+            let mut embed_inputs: HashMap<&str, Value> = HashMap::new();
+            embed_inputs.insert("input_ids", Value::from_array(input_ids.clone())?.into());
+            let embed_outputs = self.embed_session.run(embed_inputs)?;
+            let mut input_embeds = embed_outputs[0].try_extract_array::<f32>()?.to_owned();
+
+            // Replace image token embeddings
+            let mut feature_idx = 0;
+            for i in 0..input_ids.shape()[1] {
+                if input_ids[[0, i]] == self.config.image_token_id as i64 {
+                    let mut slice = input_embeds.slice_mut(s![0, i, ..]);
+                    slice.assign(&image_features_reshaped.slice(s![feature_idx, ..]));
+                    feature_idx += 1;
+                }
+            }
+
+            // Run decoder
+            let mut decoder_inputs: HashMap<&str, Value> = HashMap::new();
+            decoder_inputs.insert("inputs_embeds", Value::from_array(input_embeds.clone())?.into());
+            decoder_inputs.insert("attention_mask", Value::from_array(attention_mask.clone())?.into());
+            decoder_inputs.insert("position_ids", Value::from_array(position_ids.clone())?.into());
+
+            for (key, value) in &past_key_values {
+                decoder_inputs.insert(key, Value::from_array(value.clone())?.into());
+            }
+
+            let decoder_outputs = self.decoder_session.run(decoder_inputs)?;
+            let logits = decoder_outputs[0].try_extract_array::<f32>()?.to_owned();
+
+            // Get next token (argmax)
+            let last_idx = logits.shape()[1] - 1;
+            let logits_slice = logits.slice(s![0, last_idx, ..]);
+            let next_token = logits_slice
+                .iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+                .map(|(i, _)| i as i64)
+                .unwrap();
+
+            // Update for next iteration
+            input_ids = Array::from_vec(vec![next_token]).into_shape_with_order((1, 1))?;
+
+            let new_attention = Array::<i64, _>::ones((1, 1));
+            let total_length = attention_mask.len() + new_attention.len();
+            let mut attention_vec = attention_mask.iter().copied().collect::<Vec<i64>>();
+            attention_vec.extend(new_attention.iter().copied());
+            attention_mask = Array::from_vec(attention_vec).into_shape_with_order((1, total_length))?;
+
+            let current_pos = position_ids[[0, position_ids.shape()[1] - 1]] + 1;
+            position_ids = Array::from_vec(vec![current_pos]).into_shape_with_order((1, 1))?;
+
+            // Update past key values
+            for i in 0..self.config.num_hidden_layers {
+                let key = format!("past_key_values.{}.key", i);
+                let value = format!("past_key_values.{}.value", i);
+
+                if let Some(past_key) = past_key_values.get_mut(&key) {
+                    if i * 2 + 1 < decoder_outputs.len() {
+                        let present_key = decoder_outputs[i * 2 + 1].try_extract_array::<f32>()?.to_owned();
+                        *past_key = present_key.into_dyn();
+                    }
+                }
+
+                if let Some(past_value) = past_key_values.get_mut(&value) {
+                    if i * 2 + 2 < decoder_outputs.len() {
+                        let present_value = decoder_outputs[i * 2 + 2].try_extract_array::<f32>()?.to_owned();
+                        *past_value = present_value.into_dyn();
+                    }
+                }
+            }
+
+            generated_tokens.push(next_token as u32);
+
+            // Check for EOS
+            if next_token == self.config.eos_token_id as i64 {
+                break;
+            }
+        }
+
+        // Decode
+        let generated_text = self.tokenizer.decode(&generated_tokens, true)
+            .map_err(|e| anyhow::anyhow!("Decoding error: {}", e))?;
+        Ok(generated_text)
+    }
+}
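Note: decoding in generate() is greedy argmax with a growing KV cache; there is no sampling. A minimal usage sketch of the engine as defined above (file names and the prompt are illustrative placeholders; generate() expands the <image> marker into the repeated image tokens itself):

    use std::path::Path;

    fn run_example() -> anyhow::Result<()> {
        // Paths below are assumptions for illustration only.
        let mut vlm = VlmOnnx::new(
            Path::new("models/vision_encoder.onnx"),
            Path::new("models/embed_tokens.onnx"),
            Path::new("models/decoder_model_merged.onnx"),
            Path::new("models/tokenizer.json"),
        )?;
        let image_bytes = std::fs::read("example.jpg")?;
        let answer = vlm.generate("<image>\nDescribe this image.", &image_bytes, 512, 512)?;
        println!("{}", answer);
        Ok(())
    }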
+
+// Image preprocessing for SmolVLM
+fn preprocess_image(image: DynamicImage) -> Result<(Array5<f32>, Array4<i64>)> {
+    // Simple preprocessing - resize to 512x512 for now
+    let image = image.resize_exact(512, 512, image::imageops::FilterType::Lanczos3);
+    let image = image.to_rgb8();
+
+    let mut pixel_values = Array5::<f32>::zeros((1, 1, 3, 512, 512));
+
+    for (x, y, pixel) in image.enumerate_pixels() {
+        for c in 0..3 {
+            let val = pixel[c] as f32 / 255.0;
+            let normalized = (val - 0.5) / 0.5;
+            pixel_values[[0, 0, c, y as usize, x as usize]] = normalized;
+        }
+    }
+
+    let attention_mask = Array4::ones((1, 1, 512, 512));
+
+    Ok((pixel_values, attention_mask))
+}
+
+// Expand prompt with image tokens (simplified for now)
+fn expand_prompt_with_image(prompt: &str) -> String {
+    // For now, just repeat the <image> token 64 times as per SmolVLM2
+    let image_tokens = "<image>".repeat(64);
+    prompt.replace("<image>", &image_tokens)
+}
diff --git a/src-tauri/tests/model_manager_test.rs b/src-tauri/tests/model_manager_test.rs
index 50e1160..c75322a 100644
--- a/src-tauri/tests/model_manager_test.rs
+++ b/src-tauri/tests/model_manager_test.rs
@@ -138,7 +138,7 @@ async fn integration_test_validate_valid_repo() {
     let manager = ModelManager::new().unwrap();
 
     // Test validation of SmolVLM2 repo via the public API
-    let result = manager.validate_huggingface_repo("ggml-org/SmolVLM2-2.2B-Instruct-GGUF").await;
+    let result = manager.validate_huggingface_repo("ggml-org/SmolVLM2-2.2B-Instruct-GGUF", "Q4_K_M").await;
 
     assert!(result.is_ok(), "Should validate SmolVLM2 repo");
     let (lang_file, vision_file, lang_size, vision_size) = result.unwrap();
@@ -161,7 +161,7 @@ async fn integration_test_validate_invalid_repo() {
     let manager = ModelManager::new().unwrap();
 
     // Test with completely fake repo
-    let result = manager.validate_huggingface_repo("totally-fake/nonexistent-repo-xyz-12345").await;
+    let result = manager.validate_huggingface_repo("totally-fake/nonexistent-repo-xyz-12345", "Q4_K_M").await;
 
     assert!(result.is_err(), "Should fail on non-existent repo");
 }
@@ -174,3 +174,183 @@ async fn integration_test_download_bundled_model() {
     let is_downloaded = manager.is_model_downloaded("smolvlm2-2.2b-instruct").await;
     assert!(is_downloaded.is_ok(), "Should check bundled model status");
 }
+
+// ===== Tests for new multimodal verification functionality =====
+
+#[tokio::test]
+#[ignore] // Requires downloading a small model
+async fn integration_test_download_with_multimodal_verification_vision() {
+    use std::sync::Arc;
+    use std::sync::atomic::AtomicBool;
+
+    let temp_dir = TempDir::new().unwrap();
+    let manager = ModelManager::with_models_dir(temp_dir.path().to_path_buf());
+
+    // Download a small vision model (SmolVLM2 with Q4_K_M quantization)
+    // This is about 1.4GB total, so test with caution
+    let cancel_flag = Arc::new(AtomicBool::new(false));
+    let progress_callback = |progress: baseweightcanvas_lib::model_manager::DownloadProgress| {
+        println!("Download progress: {:.1}% - {}",
+            progress.percentage,
+            progress.file
+        );
+    };
+
+    let result = manager.download_from_huggingface(
+        "ggml-org/SmolVLM2-2.2B-Instruct-GGUF",
+        "Q4_K_M",
+        progress_callback,
+        cancel_flag,
+    ).await;
+
+    assert!(result.is_ok(), "Should successfully download and verify vision model");
+
+    // Verify the model was actually checked
+    let model_id = "ggml_org_smolvlm2_2_2b_instruct_gguf";
+    let verify_result = manager.verify_downloaded_model(model_id).await;
+
+    assert!(verify_result.is_ok(), "Should verify downloaded model");
+    let (supports_vision, supports_audio) = verify_result.unwrap();
+    assert!(supports_vision, "Downloaded model 
should support vision"); + assert!(!supports_audio, "SmolVLM2 should not support audio"); +} + +#[tokio::test] +#[ignore] // Requires actual downloaded model +async fn integration_test_verify_existing_vision_model() { + let manager = ModelManager::new().unwrap(); + + // Test with SmolVLM2 if it's already downloaded + let model_id = "smolvlm2-2.2b-instruct"; + let model_dir = manager.models_dir.join(model_id); + + if !model_dir.exists() { + eprintln!("Skipping test - SmolVLM2 not downloaded"); + return; + } + + let result = manager.verify_downloaded_model(model_id).await; + + assert!(result.is_ok(), "Should verify SmolVLM2"); + let (supports_vision, supports_audio) = result.unwrap(); + assert!(supports_vision, "SmolVLM2 should support vision"); + assert!(!supports_audio, "SmolVLM2 should not support audio"); +} + +#[tokio::test] +#[ignore] // Requires actual downloaded model +async fn integration_test_verify_unified_model() { + let manager = ModelManager::new().unwrap(); + + // Look for any downloaded unified model + let models = manager.list_downloaded_models().await.unwrap(); + + for model_id in models { + let model_dir = manager.models_dir.join(&model_id); + + // Check if it's a unified model (single GGUF, no separate mmproj) + let mut has_main_gguf = false; + let mut has_mmproj = false; + + if let Ok(entries) = std::fs::read_dir(&model_dir) { + for entry in entries.flatten() { + if let Some(filename) = entry.file_name().to_str() { + if filename.ends_with(".gguf") { + if filename.contains("mmproj") || filename.contains("vision") { + has_mmproj = true; + } else { + has_main_gguf = true; + } + } + } + } + } + + if has_main_gguf && !has_mmproj { + println!("Testing unified model: {}", model_id); + + let result = manager.verify_downloaded_model(&model_id).await; + + assert!(result.is_ok(), "Should verify unified model: {}", model_id); + let (supports_vision, supports_audio) = result.unwrap(); + assert!( + supports_vision || supports_audio, + "Unified model {} should support vision or audio", + model_id + ); + + return; // Test passed with one unified model + } + } + + eprintln!("Skipping test - no unified model found"); +} + +#[tokio::test] +async fn integration_test_verify_nonexistent_model() { + let manager = ModelManager::new().unwrap(); + + let result = manager.verify_downloaded_model("totally-fake-model-xyz-123").await; + + assert!(result.is_err(), "Should fail to verify non-existent model"); +} + +#[tokio::test] +#[ignore] // Requires models to be downloaded +async fn integration_test_get_model_paths_after_verification() { + let manager = ModelManager::new().unwrap(); + + // Get list of downloaded models + let models = manager.list_downloaded_models().await.unwrap(); + + if models.is_empty() { + eprintln!("Skipping test - no models downloaded"); + return; + } + + for model_id in models { + println!("Testing model: {}", model_id); + + // Verify the model is multimodal + let verify_result = manager.verify_downloaded_model(&model_id).await; + + if let Ok((supports_vision, supports_audio)) = verify_result { + assert!( + supports_vision || supports_audio, + "Model {} should support vision or audio", + model_id + ); + + // Get model paths + let paths_result = manager.get_model_paths(&model_id).await; + assert!(paths_result.is_ok(), "Should get paths for verified model: {}", model_id); + + let (lang_path, mmproj_path) = paths_result.unwrap(); + assert!(lang_path.exists(), "Language model should exist: {:?}", lang_path); + assert!(mmproj_path.exists(), "MMProj should exist: {:?}", 
mmproj_path); + + println!(" ✓ Verified and got paths for {}", model_id); + println!(" Vision: {}, Audio: {}", supports_vision, supports_audio); + } + } +} + +// Test that the download flow properly rejects non-multimodal models +#[tokio::test] +#[ignore] // Would require attempting to download a text-only model +async fn integration_test_reject_text_only_model() { + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + + let temp_dir = TempDir::new().unwrap(); + let _manager = ModelManager::with_models_dir(temp_dir.path().to_path_buf()); + + // Note: This test is conceptual - we don't actually have a good + // text-only GGUF model in HuggingFace to test with + // The verification would happen after download and should clean up + // the files if the model doesn't support vision/audio + + eprintln!("Note: This test documents expected behavior"); + eprintln!("If a text-only model is downloaded, it should be rejected during verification"); + eprintln!("and the downloaded files should be cleaned up"); +} diff --git a/src/App.tsx b/src/App.tsx index fa83dc4..cf3cfef 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -7,7 +7,7 @@ import { ImageViewer } from "./components/ImageViewer"; import { ChatPanel } from "./components/ChatPanel"; import { ModelSelectionModal } from "./components/ModelSelectionModal"; import { DownloadModelDialog } from "./components/DownloadModelDialog"; -import type { MediaItem, Model, AvailableModel, ChatMessage } from "./types"; +import type { MediaItem, Model, AvailableModel, OnnxModel, ChatMessage } from "./types"; import "./App.css"; // Bundled model that ships with Baseweight Canvas @@ -34,7 +34,7 @@ const BUNDLED_MODEL: Model = { // Mock available models (for download) - Real models from HuggingFace collections const MOCK_AVAILABLE_MODELS: AvailableModel[] = [ - // Popular SmolVLM2 models (HuggingFace) + // SmolVLM2 2.2B Instruct - Best compact vision-language model { id: 'smolvlm2-2.2b-instruct', name: 'SmolVLM2-2.2B-Instruct-GGUF', @@ -47,72 +47,8 @@ const MOCK_AVAILABLE_MODELS: AvailableModel[] = [ quantization: 'Q4_K_M', description: 'Popular compact vision-language model from HuggingFace, excellent quality-to-size ratio', }, - { - id: 'ggml_org_smolvlm2_500m_video_instruct_gguf', - name: 'SmolVLM2-500M-Video-Instruct-GGUF', - displayName: 'SmolVLM2 500M Video', - task: 'general-vlm', - taskDescription: 'General Vision-Language Model', - backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/ggml-org/SmolVLM2-500M-Video-Instruct-GGUF', - size: 0.3 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'Tiny video-capable vision model, supports both images and video frames', - }, - - // PaliGemma (Google vision-language) - { - id: 'paligemma-3b', - name: 'paligemma-3b-mix-224-gguf', - displayName: 'PaliGemma 3B', - task: 'general-vlm', - taskDescription: 'General Vision-Language Model', - backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/abetlen/paligemma-3b-mix-224-gguf', - size: 2.5 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'Google\'s vision-language model based on Gemma, strong OCR and visual reasoning', - }, - // ggml-org multimodal models - { - id: 'mistral-small-3.1-24b', - name: 'Mistral-Small-3.1-24B-Instruct-2503-GGUF', - displayName: 'Mistral Small 3.1 24B Instruct', - task: 'general-vlm', - taskDescription: 'General Vision-Language Model', - backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF', - size: 14 * 1024 * 1024 * 1024, - 
quantization: 'Q4_K_M', - description: 'Powerful 24B parameter vision-language model with strong instruction following', - }, - { - id: 'pixtral-12b', - name: 'pixtral-12b-GGUF', - displayName: 'Pixtral 12B', - task: 'general-vlm', - taskDescription: 'General Vision-Language Model', - backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/ggml-org/pixtral-12b-GGUF', - size: 7 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'High-performance 12B vision model from Mistral AI', - }, - { - id: 'moondream2-2025', - name: 'moondream2-20250414-GGUF', - displayName: 'Moondream 2 (2025)', - task: 'general-vlm', - taskDescription: 'General Vision-Language Model', - backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/ggml-org/moondream2-20250414-GGUF', - size: 0.3 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'Ultra-compact 500M parameter vision model, perfect for edge devices', - }, - - // Audio models + // Ultravox - Audio-capable language model { id: 'ggml_org_ultravox_v0_5_llama_3_2_1b_gguf', name: 'ultravox-v0_5-llama-3_2-1b-GGUF', @@ -126,44 +62,54 @@ const MOCK_AVAILABLE_MODELS: AvailableModel[] = [ description: 'Audio-capable language model based on Llama 3.2, supports speech understanding', }, - // LiquidAI LFM2-VL models + // LFM2-VL 450M - Ultra-lightweight vision model { - id: 'lfm2-vl-3b', - name: 'LFM2-VL-3B-GGUF', - displayName: 'LFM2-VL 3B', + id: 'lfm2-vl-450m', + name: 'LFM2-VL-450M-GGUF', + displayName: 'LFM2-VL 450M', task: 'general-vlm', taskDescription: 'General Vision-Language Model', backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/LiquidAI/LFM2-VL-3B-GGUF', - size: 2 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'On-device optimized 3B vision-language model from Liquid AI', + huggingfaceUrl: 'https://huggingface.co/LiquidAI/LFM2-VL-450M-GGUF', + size: 0.3 * 1024 * 1024 * 1024, + quantization: 'Q4_0', + description: 'Ultra-lightweight 450M model for resource-constrained environments', }, + + // Ministral 3 14B - Mistral's latest vision model { - id: 'lfm2-vl-1.6b', - name: 'LFM2-VL-1.6B-GGUF', - displayName: 'LFM2-VL 1.6B', + id: 'mistralai_ministral_3_14b_instruct_2512_gguf', + name: 'Ministral-3-14B-Instruct-2512-GGUF', + displayName: 'Ministral 3 14B Instruct', task: 'general-vlm', taskDescription: 'General Vision-Language Model', backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/LiquidAI/LFM2-VL-1.6B-GGUF', - size: 1 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'Efficient 1.6B model designed for on-device deployment', + huggingfaceUrl: 'https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512-GGUF', + size: 8.5 * 1024 * 1024 * 1024, + quantization: 'Q4_K_M (BF16 mmproj)', + description: 'Mistral\'s latest 14B vision-language model with unified architecture and strong performance', }, + +]; + +// ONNX models for download +const ONNX_MODELS: OnnxModel[] = [ { - id: 'lfm2-vl-450m', - name: 'LFM2-VL-450M-GGUF', - displayName: 'LFM2-VL 450M', + id: 'smolvlm2-256m-video-instruct-onnx', + name: 'SmolVLM2-256M-Video-Instruct-ONNX', + displayName: 'SmolVLM2 256M Video Instruct (ONNX)', task: 'general-vlm', taskDescription: 'General Vision-Language Model', - backend: 'llama.cpp', - huggingfaceUrl: 'https://huggingface.co/LiquidAI/LFM2-VL-450M-GGUF', - size: 0.3 * 1024 * 1024 * 1024, - quantization: 'Q4_K_M', - description: 'Ultra-lightweight 450M model for resource-constrained environments', + huggingfaceRepo: 'HuggingFaceTB/SmolVLM2-256M-Video-Instruct', + 
huggingfaceUrl: 'https://huggingface.co/HuggingFaceTB/SmolVLM2-256M-Video-Instruct', + quantizations: ['Q4', 'Q8', 'FP16'], + estimatedSizes: { + 'Q4': 0.3 * 1024 * 1024 * 1024, + 'Q8': 0.5 * 1024 * 1024 * 1024, + 'FP16': 0.8 * 1024 * 1024 * 1024, + }, + description: 'SmolVLM2 256M ONNX Runtime backend - compact model optimized for GPU acceleration', }, - ]; interface DownloadProgress { @@ -231,6 +177,7 @@ function App() { const loadDownloadedModels = async () => { try { const modelIds = await invoke('list_downloaded_models'); + console.log('Downloaded model IDs:', modelIds); // Convert model IDs to full Model objects const models: Model[] = modelIds.map(modelId => { @@ -239,7 +186,38 @@ function App() { return BUNDLED_MODEL; } - // Find in available models + // Check if it's an ONNX model + if (modelId.includes('_onnx_')) { + const baseName = modelId.split('_onnx_')[0]; + const quantFromId = modelId.split('_onnx_')[1]; + console.log(`ONNX model detected: ${modelId}, baseName: ${baseName}, quant: ${quantFromId}`); + + const onnxModel = ONNX_MODELS.find(m => { + const normalizedRepo = m.huggingfaceRepo.replace('/', '_').replace(/-/g, '_').toLowerCase(); + console.log(` Comparing baseName "${baseName}" with normalized "${normalizedRepo}"`); + return normalizedRepo === baseName; + }); + + console.log(` Found ONNX model match:`, onnxModel ? onnxModel.displayName : 'NO MATCH'); + + if (onnxModel) { + return { + id: modelId, + name: onnxModel.name, + displayName: `${onnxModel.displayName} (${quantFromId.toUpperCase()})`, + task: onnxModel.task, + taskDescription: onnxModel.taskDescription, + backend: 'onnx-runtime' as const, + huggingfaceUrl: onnxModel.huggingfaceUrl, + size: onnxModel.estimatedSizes[quantFromId.toUpperCase()] || 0, + downloaded: true, + quantization: quantFromId.toUpperCase(), + localPath: `/models/${modelId}`, + } as Model; + } + } + + // Find in available models (GGUF models) const availableModel = MOCK_AVAILABLE_MODELS.find(m => m.id === modelId); if (availableModel) { return { @@ -366,23 +344,36 @@ function App() { try { setIsModelLoading(true); - console.log('Loading model:', currentModel.id); + console.log('Loading model:', currentModel.id, 'backend:', currentModel.backend); - await invoke('load_model', { - modelId: currentModel.id, - nGpuLayers: 999, // Use all available GPU layers - }); - - console.log('Model loaded successfully'); + // Load based on backend type + if (currentModel.backend === 'onnx-runtime') { + console.log('Loading ONNX model'); + await invoke('load_onnx_model', { + modelId: currentModel.id, + }); + console.log('ONNX model loaded successfully'); - // Check if model supports audio - try { - const supportsAudio = await invoke('check_audio_support'); - setIsAudioCapable(supportsAudio); - console.log('Model audio support:', supportsAudio); - } catch (error) { - console.error('Failed to check audio support:', error); + // ONNX models don't support audio yet setIsAudioCapable(false); + } else { + // llama.cpp backend + console.log('Loading llama.cpp model'); + await invoke('load_model', { + modelId: currentModel.id, + nGpuLayers: 999, // Use all available GPU layers + }); + console.log('llama.cpp model loaded successfully'); + + // Check if model supports audio (llama.cpp only) + try { + const supportsAudio = await invoke('check_audio_support'); + setIsAudioCapable(supportsAudio); + console.log('Model audio support:', supportsAudio); + } catch (error) { + console.error('Failed to check audio support:', error); + setIsAudioCapable(false); + } } 
setIsModelLoading(false); @@ -588,13 +579,40 @@ function App() { console.log('Sending conversation with', conversation.length, 'messages'); - // Call inference with full conversation - response = await invoke('generate_response', { - conversation, - imageData: rgbData, - imageWidth: img.width, - imageHeight: img.height, - }); + // Call inference based on backend type + if (currentModel.backend === 'onnx-runtime') { + // ONNX Runtime - simpler single-turn for now + // Convert RGBA to raw bytes for ONNX + const canvas2 = document.createElement('canvas'); + canvas2.width = img.width; + canvas2.height = img.height; + const ctx2 = canvas2.getContext('2d'); + if (!ctx2) throw new Error('Failed to get canvas context'); + ctx2.drawImage(img, 0, 0); + + // Get image as JPEG bytes + const blob = await new Promise((resolve) => { + canvas2.toBlob((b) => resolve(b!), 'image/jpeg', 0.95); + }); + const arrayBuffer = await blob.arrayBuffer(); + const imageBytes = Array.from(new Uint8Array(arrayBuffer)); + + console.log('Calling ONNX inference with image:', img.width, 'x', img.height); + response = await invoke('generate_onnx_response', { + prompt: content, + imageData: imageBytes, + imageWidth: img.width, + imageHeight: img.height, + }); + } else { + // llama.cpp backend - full conversation support + response = await invoke('generate_response', { + conversation, + imageData: rgbData, + imageWidth: img.width, + imageHeight: img.height, + }); + } } const assistantMessage: ChatMessage = { @@ -666,7 +684,34 @@ function App() { return BUNDLED_MODEL; } - // Find in available models + // Check if it's an ONNX model + if (modelId.includes('_onnx_')) { + const baseName = modelId.split('_onnx_')[0]; + const quantFromId = modelId.split('_onnx_')[1]; + + const onnxModel = ONNX_MODELS.find(m => { + const normalizedRepo = m.huggingfaceRepo.replace('/', '_').replace(/-/g, '_').toLowerCase(); + return normalizedRepo === baseName; + }); + + if (onnxModel) { + return { + id: modelId, + name: onnxModel.name, + displayName: `${onnxModel.displayName} (${quantFromId.toUpperCase()})`, + task: onnxModel.task, + taskDescription: onnxModel.taskDescription, + backend: 'onnx-runtime' as const, + huggingfaceUrl: onnxModel.huggingfaceUrl, + size: onnxModel.estimatedSizes[quantFromId.toUpperCase()] || 0, + downloaded: true, + quantization: quantFromId.toUpperCase(), + localPath: `/models/${modelId}`, + } as Model; + } + } + + // Find in available models (GGUF models) const availableModel = MOCK_AVAILABLE_MODELS.find(m => m.id === modelId); if (availableModel) { return { @@ -711,19 +756,149 @@ function App() { } }; - const handleAddModel = (repo: string, quantization: string) => { + const handleDownloadOnnxModel = async (repo: string, quantization: string) => { + console.log('Download ONNX model:', repo, quantization); + + setIsDownloading(true); + setIsDownloadDialogOpen(true); + setFileProgress({}); + + try { + // Start the ONNX model download + const modelId = await invoke('download_onnx_model', { + repo, + quantization, + }); + + console.log('ONNX model download completed, model ID:', modelId); + + // Refresh downloaded models list + const downloadedModelIds = await invoke('list_downloaded_models'); + const newDownloadedModels = downloadedModelIds.map(id => { + // Check if it's a ONNX model + if (id.includes('_onnx_')) { + const onnxModel = ONNX_MODELS.find(m => modelId === id || m.huggingfaceRepo.replace('/', '_').toLowerCase().includes(id.split('_onnx_')[0])); + if (onnxModel) { + return { + id: modelId, + name: onnxModel.name, + 
displayName: onnxModel.displayName, + task: onnxModel.task, + taskDescription: onnxModel.taskDescription, + backend: 'onnx-runtime' as const, + huggingfaceUrl: onnxModel.huggingfaceUrl, + size: onnxModel.estimatedSizes[quantization] || 0, + downloaded: true, + quantization, + localPath: `/models/${modelId}`, + }; + } + } + + // Otherwise handle as GGUF model + if (id === BUNDLED_MODEL.id) { + return BUNDLED_MODEL; + } + + const availableModel = MOCK_AVAILABLE_MODELS.find(m => m.id === id); + if (availableModel) { + return { + ...availableModel, + localPath: `/models/${id}`, + downloaded: true, + }; + } + + return null; + }).filter((m): m is Model => m !== null); + + setDownloadedModels(newDownloadedModels); + + setIsDownloading(false); + setIsDownloadDialogOpen(false); + + if (modelId) { + alert(`ONNX model downloaded successfully!\nModel ID: ${modelId}`); + } + } catch (error) { + console.error('Failed to download ONNX model:', error); + alert(`Failed to download ONNX model: ${error}`); + setIsDownloading(false); + setIsDownloadDialogOpen(false); + } + }; + + const handleAddModel = async (repo: string, quantization: string) => { console.log('Add model from HuggingFace:', repo, quantization); - // TODO: Implement actual HuggingFace download logic - alert( - `Add Model from HuggingFace\n\n` + - `Repository: ${repo}\n` + - `Quantization: ${quantization}\n\n` + - `This will be implemented with the Rust backend:\n` + - `1. Fetch model card from HuggingFace\n` + - `2. Locate mmproj-*.gguf and model-${quantization.toLowerCase()}.gguf\n` + - `3. Download both files\n` + - `4. Add to downloaded models list` - ); + + try { + setIsDownloading(true); + setFileProgress({}); // Clear previous download progress + setIsDownloadDialogOpen(true); // Show download dialog + setIsModelModalOpen(false); // Close model selection modal + downloadCancelledRef.current = false; // Reset cancel flag + + // Call the download_model command with quantization + const downloadedModelId = await invoke('download_model', { repo, quantization }); + + // Model download complete + setIsDownloading(false); + setIsDownloadDialogOpen(false); + + // Refresh the entire downloaded models list from backend + const modelIds = await invoke('list_downloaded_models'); + const models: Model[] = modelIds.map(modelId => { + // Check if it's the bundled model + if (modelId === BUNDLED_MODEL.id) { + return BUNDLED_MODEL; + } + + // Find in available models + const availableModel = MOCK_AVAILABLE_MODELS.find(m => m.id === modelId); + if (availableModel) { + return { + ...availableModel, + localPath: `/models/${modelId}`, + downloaded: true, + } as Model; + } + + // Fallback for unknown models (custom models added via "Add Model") + return { + id: modelId, + name: modelId, + displayName: modelId.split('_').map(word => + word.charAt(0).toUpperCase() + word.slice(1) + ).join(' '), + task: 'general-vlm', + taskDescription: 'General Vision-Language Model', + backend: 'llama.cpp', + huggingfaceUrl: `https://huggingface.co/${repo}`, + size: 0, + localPath: `/models/${modelId}`, + downloaded: true, + quantization: quantization, + } as Model; + }); + + setDownloadedModels(models); + + // Only load the model if download wasn't cancelled + if (!downloadCancelledRef.current) { + // Find and load the downloaded model + const downloadedModel = models.find(m => m.id === downloadedModelId); + if (downloadedModel) { + setCurrentModel(downloadedModel); + } + + alert(`Model downloaded successfully!\nModel ID: ${downloadedModelId}`); + } + } catch (error) { + 
console.error('Failed to download model:', error); + alert(`Failed to download model: ${error}`); + setIsDownloading(false); + setIsDownloadDialogOpen(false); + } }; const handleDownloadBundledModel = async () => { @@ -783,8 +958,10 @@ function App() { currentModel={currentModel} downloadedModels={downloadedModels} availableModels={MOCK_AVAILABLE_MODELS} + onnxModels={ONNX_MODELS} onSelectModel={handleSelectModel} onDownloadModel={handleDownloadModel} + onDownloadOnnxModel={handleDownloadOnnxModel} onAddModel={handleAddModel} /> void; onDownloadModel: (modelId: string) => void; + onDownloadOnnxModel: (repo: string, quantization: string) => void; onAddModel: (repo: string, quantization: string) => void; } @@ -20,12 +22,15 @@ export const ModelSelectionModal: React.FC = ({ currentModel, downloadedModels, availableModels, + onnxModels, onSelectModel, onDownloadModel, + onDownloadOnnxModel, onAddModel, }) => { - const [activeTab, setActiveTab] = useState<'downloaded' | 'available'>('downloaded'); + const [activeTab, setActiveTab] = useState<'downloaded' | 'available' | 'onnx'>('downloaded'); const [isAddDialogOpen, setIsAddDialogOpen] = useState(false); + const [selectedQuantizations, setSelectedQuantizations] = useState<{ [key: string]: string }>({}); if (!isOpen) return null; @@ -80,6 +85,12 @@ export const ModelSelectionModal: React.FC = ({ > Recommended Models ({availableModels.length}) +
@@ -149,7 +160,7 @@ export const ModelSelectionModal: React.FC<ModelSelectionModalProps> = ({
               ))
             )}
- ) : ( + ) : activeTab === 'available' ? (
           {availableModels.map(model => {
             const isDownloaded = downloadedModels.some(m => m.id === model.id);
@@ -206,6 +217,71 @@ export const ModelSelectionModal: React.FC<ModelSelectionModalProps> = ({
               );
             })}
+ ) : ( +
+ {onnxModels.map(model => { + const selectedQuant = selectedQuantizations[model.id] || model.quantizations[0]; + const estimatedSize = model.estimatedSizes[selectedQuant] || 0; + + return ( +
+
+
+

{model.displayName}

+
+ + {model.taskDescription} + + ONNX Runtime +
+ {model.description && ( +

{model.description}

+ )} +
+ + +
+
+
{formatSize(estimatedSize)}
+
+ +
+ + View on HuggingFace → + + +
+
+ ); + })} +
)} diff --git a/src/types/index.ts b/src/types/index.ts index 3ce6f9c..baff8b6 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -93,3 +93,16 @@ export interface AvailableModel { quantization?: string; description?: string; } + +export interface OnnxModel { + id: string; + name: string; + displayName: string; + task: ModelTask; + taskDescription: string; + huggingfaceRepo: string; + huggingfaceUrl: string; + quantizations: string[]; // Available quantizations (Q4, Q8, FP16) + estimatedSizes: { [key: string]: number }; // Size per quantization + description?: string; +}