diff --git a/Cargo.lock b/Cargo.lock index 3f6ed61fba..a6e9d70a72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,9 +40,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom 0.2.15", "once_cell", - "serde", "version_check", "zerocopy 0.7.35", ] @@ -165,9 +163,9 @@ checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063" [[package]] name = "async-trait" -version = "0.1.87" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d556ec1359574147ec0c4fc5eb525f3f23263a592b1a9c07e0a75b427de55c97" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", @@ -328,12 +326,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "borrow-or-share" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32" - [[package]] name = "bumpalo" version = "3.17.0" @@ -429,27 +421,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "candle-hf-hub" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca5f45ce8fe55a9e9246a3fc60000d7ed11b88a84d72f753488f7264ce04b102" -dependencies = [ - "dirs", - "futures", - "http", - "indicatif", - "log", - "num_cpus", - "rand 0.8.5", - "reqwest", - "serde", - "serde_json", - "thiserror 1.0.69", - "tokio", - "ureq", -] - [[package]] name = "candle-kernels" version = "0.8.0" @@ -487,25 +458,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "cbindgen" -version = "0.27.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb" -dependencies = [ - "clap", - "heck 0.4.1", - "indexmap", - "log", - "proc-macro2", - "quote", - "serde", - "serde_json", - "syn 2.0.100", - "tempfile", - "toml", -] - [[package]] name = "cc" version = "1.2.16" @@ -606,7 +558,7 @@ version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.100", @@ -966,15 +918,16 @@ dependencies = [ [[package]] name = "derivre" -version = "0.1.0" -source = "git+https://github.com/microsoft/derivre?rev=02ee497e6e404a0b402b4f68a9abf599d22ed2ed#02ee497e6e404a0b402b4f68a9abf599d22ed2ed" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a3c2606b3ffc46f91fd62d954d55659ba9fb391bb673311b70f50daf9c15e49" dependencies = [ - "ahash", "anyhow", "bytemuck", "bytemuck_derive", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "regex-syntax 0.8.5", + "strum 0.27.1", ] [[package]] @@ -1086,7 +1039,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.100", @@ -1142,12 +1095,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "fastrand" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" - [[package]] name = "fdeflate" version = "0.3.7" @@ -1192,32 +1139,12 @@ dependencies = [ "rand_distr", ] -[[package]] -name = "fluent-uri" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5" -dependencies = [ - "borrow-or-share", - "ref-cast", - "serde", -] - [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared 0.1.1", -] - [[package]] name = "foreign-types" version = "0.5.0" @@ -1225,7 +1152,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared 0.3.1", + "foreign-types-shared", ] [[package]] @@ -1239,12 +1166,6 @@ dependencies = [ "syn 2.0.100", ] -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1595,12 +1516,6 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -1620,14 +1535,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc03dcb0b0a83ae3f3363ec811014ae669f083e4e499c66602f447c4828737a1" dependencies = [ "dirs", + "futures", "http", "indicatif", "libc", "log", + "num_cpus", "rand 0.8.5", + "reqwest", "serde", "serde_json", "thiserror 2.0.12", + "tokio", "ureq", "windows-sys 0.59.0", ] @@ -1717,22 +1636,6 @@ dependencies = [ "webpki-roots", ] -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", -] - [[package]] name = "hyper-util" version = "0.1.10" @@ -1987,15 +1890,6 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" -[[package]] -name = "instant" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" -dependencies = [ - "cfg-if", -] - [[package]] name = "intel-mkl-src" version = "0.8.1" @@ -2142,21 +2036,16 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "llguidance" -version = "0.4.1" -source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d74976ff41cc9cb78bca6cc#cfef3df97372a7b84d74976ff41cc9cb78bca6cc" +version = "0.7.0" +source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386" dependencies = [ "anyhow", - "cbindgen", "derivre", "indexmap", - "instant", - "referencing", "regex-syntax 0.8.5", - "rustc-hash", "serde", "serde_json", "toktrie", - "url", ] [[package]] @@ -2262,7 +2151,7 @@ dependencies = [ "bitflags 2.9.0", "block", "core-graphics-types", - "foreign-types 0.5.0", + "foreign-types", "log", "objc", "paste", @@ -2394,7 +2283,6 @@ dependencies = [ "candle-core", "candle-flash-attn", "candle-flash-attn-v3", - "candle-hf-hub", "candle-nn", "cfgrammar", "chrono", @@ -2408,6 +2296,7 @@ dependencies = [ "futures", "galil-seiferas", "half", + "hf-hub", "image", "indexmap", "indicatif", @@ -2440,7 +2329,7 @@ dependencies = [ "serde_json", "serde_plain", "serde_yaml", - "strum", + "strum 0.26.3", "sysinfo", "thiserror 1.0.69", "tokenizers", @@ -2504,6 +2393,7 @@ dependencies = [ "candle-nn", "float8", "half", + "hf-hub", "lazy_static", "memmap2", "metal", @@ -2579,23 +2469,6 @@ dependencies = [ "syn 2.0.100", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nibble_vec" version = "0.1.0" @@ -2771,8 +2644,8 @@ dependencies = [ "getset", "serde", "serde_json", - "strum", - "strum_macros", + "strum 0.26.3", + "strum_macros 0.26.4", "thiserror 1.0.69", ] @@ -2831,50 +2704,6 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "openssl" -version = "0.10.71" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" -dependencies = [ - "bitflags 2.9.0", - "cfg-if", - "foreign-types 0.3.2", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - -[[package]] -name = "openssl-sys" -version = "0.9.106" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "option-ext" version = "0.2.0" @@ -3170,7 +2999,7 @@ version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "pyo3-build-config", "quote", @@ -3414,39 +3243,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "ref-cast" -version = "1.0.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" -dependencies = [ - "ref-cast-impl", -] - -[[package]] -name = "ref-cast-impl" -version = "1.0.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - -[[package]] -name = "referencing" -version = "0.26.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb8e15af8558cb157432dd3d88c1d1e982d0a5755cf80ce593b6499260aebc49" -dependencies = [ - "ahash", - "fluent-uri", - "once_cell", - "percent-encoding", - "serde_json", -] - [[package]] name = "regex" version = "1.11.1" @@ -3509,13 +3305,11 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", - "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", - "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -3529,13 +3323,14 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "windows-registry", @@ -3708,15 +3503,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" -dependencies = [ - "windows-sys 0.59.0", -] - [[package]] name = "schemars" version = "0.8.22" @@ -3747,29 +3533,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags 2.9.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.26" @@ -4040,7 +3803,16 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ - "strum_macros", + "strum_macros 0.26.4", +] + +[[package]] +name = "strum" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" +dependencies = [ + "strum_macros 0.27.1", ] [[package]] @@ -4049,7 +3821,20 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.100", +] + +[[package]] +name = "strum_macros" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +dependencies = [ + "heck", "proc-macro2", "quote", "rustversion", @@ -4171,19 +3956,6 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" -[[package]] -name = "tempfile" -version = "3.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488960f40a3fd53d72c2a29a58722561dee8afdd175bd88e3db4677d7b2ba600" -dependencies = [ - "fastrand", - "getrandom 0.3.1", - "once_cell", - "rustix", - "windows-sys 0.59.0", -] - [[package]] name = "termcolor" version = "1.4.1" @@ -4299,7 +4071,6 @@ dependencies = [ "derive_builder", "esaxx-rs", "getrandom 0.2.15", - "hf-hub", "indicatif", "itertools 0.13.0", "lazy_static", @@ -4351,16 +4122,6 @@ dependencies = [ "syn 2.0.100", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rayon" version = "2.1.0" @@ -4396,25 +4157,23 @@ dependencies = [ [[package]] name = "toktrie" -version = "0.1.0" -source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d74976ff41cc9cb78bca6cc#cfef3df97372a7b84d74976ff41cc9cb78bca6cc" +version = "0.7.0" +source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386" dependencies = [ "anyhow", "bytemuck", "bytemuck_derive", - "rustc-hash", "serde", "serde_json", ] [[package]] name = "toktrie_hf_tokenizers" -version = "0.1.0" -source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d74976ff41cc9cb78bca6cc#cfef3df97372a7b84d74976ff41cc9cb78bca6cc" +version = "0.7.0" +source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386" dependencies = [ "anyhow", "log", - "rustc-hash", "serde", "serde_json", "tokenizers", @@ -4782,12 +4541,6 @@ dependencies = [ "uuid 0.8.2", ] -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.5" @@ -4910,6 +4663,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -4947,9 +4713,9 @@ checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "widestring" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311" +checksum = "dd7cf3379ca1aac9eea11fba24fd7e315d621f8dfe35c8d7d2be8b793726e07d" [[package]] name = "winapi" diff --git a/Cargo.toml b/Cargo.toml index 28563e787c..0a9093d8c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] } once_cell = "1.19.0" # All features but avif, avif increases the msrv dramatically image = { version = "0.25.1", default-features = false, features = ['bmp', 'dds', 'exr', 'ff', 'gif', 'hdr', 'ico', 'jpeg', 'png', 'pnm', 'qoi', 'tga', 'tiff', 'webp']} -reqwest = { version = "0.12.4", features = ["blocking"] } +reqwest = { version = "0.12.4", default-features = false, features = ["blocking", "rustls-tls", "charset", "http2", "macos-system-configuration"] } base64 = "0.22.1" half = "2.4.0" rayon = "1.1.0" @@ -52,3 +52,4 @@ regex = "1.10.6" metal = { version = "0.27.0", features = ["mps"] } safetensors = "0.4.5" toml = "0.8.12" +hf-hub = { version = "0.4.1", default-features = false, features = ["ureq", "tokio", "rustls-tls"] } diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index bcc2b17305..3a0f02c27b 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -19,9 +19,9 @@ serde.workspace = true serde_json.workspace = true candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "496a8d2b", optional = true } dirs = "5.0.1" -hf-hub = { version = "0.3.3", package = "candle-hf-hub" } +hf-hub.workspace = true thiserror = "1.0.57" -tokenizers = "0.21.0" +tokenizers = { version = "0.21.0", default-features = false } tqdm = "0.7.0" chrono = "0.4.34" minijinja = { version = "2.0.2", features = ["builtins", "json"] } @@ -77,8 +77,8 @@ regex.workspace = true serde_plain = "1.0.2" as-any = "0.3.1" float8.workspace = true -llguidance = { git = "https://github.com/microsoft/llguidance", rev = "cfef3df97372a7b84d74976ff41cc9cb78bca6cc", default-features = false, features = ["lark"] } -toktrie_hf_tokenizers = { git = "https://github.com/microsoft/llguidance", rev = "cfef3df97372a7b84d74976ff41cc9cb78bca6cc" } +llguidance = { git = "https://github.com/EricLBuehler/llguidance", rev = "8d71957", default-features = false, features = ["lark"] } +toktrie_hf_tokenizers = { git = "https://github.com/EricLBuehler/llguidance", rev = "8d71957" } objc = { version = "0.2.7", optional = true } metal = { workspace = true, optional = true } candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "496a8d2b", optional = true } diff --git a/mistralrs-core/src/pipeline/llg.rs b/mistralrs-core/src/pipeline/llg.rs index ec0212775e..c57b080a43 100644 --- a/mistralrs-core/src/pipeline/llg.rs +++ b/mistralrs-core/src/pipeline/llg.rs @@ -2,10 +2,9 @@ use std::sync::Arc; use anyhow::Result; use llguidance::{ - api::{ParserLimits, RegexNode, TopLevelGrammar}, - lark_to_llguidance, + api::{ParserLimits, TopLevelGrammar}, toktrie::{InferenceCapabilities, TokEnv}, - JsonCompileOptions, TokenParser, + TokenParser, }; use tokenizers::Tokenizer; @@ -21,13 +20,9 @@ pub fn build_tok_env(tokenizer: Tokenizer) -> TokEnv { pub fn llg_grammar_from_constraint(constraint: &Constraint) -> Result> { let grm = match constraint { - Constraint::Regex(regex) => { - TopLevelGrammar::from_regex(RegexNode::Regex(regex.to_string())) - } - Constraint::Lark(lark) => lark_to_llguidance(lark)?, - Constraint::JsonSchema(value) => { - JsonCompileOptions::default().json_to_llg_no_validate(value.clone())? - } + Constraint::Regex(regex) => TopLevelGrammar::from_regex(regex), + Constraint::Lark(lark) => TopLevelGrammar::from_lark(lark.clone()), + Constraint::JsonSchema(value) => TopLevelGrammar::from_json_schema(value.clone()), Constraint::Llguidance(value) => value.clone(), Constraint::None => return Ok(None), }; @@ -38,7 +33,7 @@ pub fn constraint_from_llg_grammar( tok_env: TokEnv, grm: TopLevelGrammar, ) -> Result { - let parser = TokenParser::from_llguidance_json( + let parser = TokenParser::from_grammar( tok_env, grm, llguidance::Logger::new(0, 1), diff --git a/mistralrs-quant/Cargo.toml b/mistralrs-quant/Cargo.toml index ba62e797ed..6dd3b3da6d 100644 --- a/mistralrs-quant/Cargo.toml +++ b/mistralrs-quant/Cargo.toml @@ -30,6 +30,7 @@ yoke = "0.7.5" memmap2 = "0.9.5" safetensors.workspace = true regex.workspace = true +hf-hub.workspace = true [features] cuda = [ diff --git a/mistralrs-quant/src/blockwise_fp8/ops.rs b/mistralrs-quant/src/blockwise_fp8/ops.rs index 04b54177d7..822a9a7248 100644 --- a/mistralrs-quant/src/blockwise_fp8/ops.rs +++ b/mistralrs-quant/src/blockwise_fp8/ops.rs @@ -250,11 +250,14 @@ pub fn fp8_blockwise_dequantize( } #[cfg(test)] +#[allow(unused_imports)] mod tests { use candle_core::{DType, Device, Result, Tensor}; + use candle_nn::{Linear, Module}; use half::bf16; + use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; - use crate::blockwise_fp8::ops; + use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors}; #[test] fn test_fp8_blockwise_dequant() -> Result<()> { @@ -455,4 +458,46 @@ mod tests { Ok(()) } + + #[cfg(feature = "cuda")] + #[test] + fn test_blockwise_fp8_gemm() -> Result<()> { + let dev = Device::cuda_if_available(0)?; + + let api = ApiBuilder::new().with_progress(true).build().unwrap(); + let api = api.repo(Repo::with_revision( + "EricB/mistralrs_tests".to_string(), + RepoType::Model, + "main".to_string(), + )); + + let filename = api.get("test_fp8.safetensors").unwrap(); + let vb = unsafe { MmapedSafetensors::new(filename)? }; + + let weight = vb.load("weight", &dev, None)?; + assert_eq!((7168, 2048), weight.dims2()?); + assert_eq!(DType::F8E4M3, weight.dtype()); + + let scale = vb.load("scale", &dev, None)?; + assert_eq!((56, 16), scale.dims2()?); + assert_eq!(DType::F32, scale.dtype()); + + let weight_block_size = vec![128, 128]; + + // in dim is 2048. + let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?; + + let truth = { + let weight_dq = + ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?; + + let lin_dq = Linear::new(weight_dq, None); + lin_dq.forward(&xs)? + }; + + // TODO: will be adding real blockwise fp8 gemm shortly ;) + assert_eq!((32, 7168), truth.dims2()?); + + Ok(()) + } } diff --git a/mistralrs-quant/src/distributed/socket.rs b/mistralrs-quant/src/distributed/socket.rs index 6072de5629..97b50f1187 100644 --- a/mistralrs-quant/src/distributed/socket.rs +++ b/mistralrs-quant/src/distributed/socket.rs @@ -47,6 +47,7 @@ impl Server { pub fn broadcast_id(&self, id: &Id) -> Result<()> { let body = id.internal(); // SAFETY: We know the provenance and lifetime of `body` are valid. + #[allow(clippy::unnecessary_cast)] let body_bytes = unsafe { slice::from_raw_parts(body.as_ptr() as *const u8, body.len()) }; for mut stream in &self.connections { stream.write_all(body_bytes)?; diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs index 3694bb2484..d8835ae0f4 100644 --- a/mistralrs-quant/src/hqq/mod.rs +++ b/mistralrs-quant/src/hqq/mod.rs @@ -152,11 +152,12 @@ impl HqqBits { (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize, wq_in.dims()[1], ), - DType::I32, + DType::U32, wq_in.device(), )?; - let wq = - wq.slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::I32)?)?; + let wq = wq + .slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::U32)?)? + .to_dtype(DType::I32)?; let step = (wq.dims()[0] as f64 / 10.) as usize; let a = wq.narrow(0, 0, step)?; diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index 7deae75ed5..b34cba2d2c 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -68,9 +68,24 @@ impl CustomOp2 for BitWiseOr { let result = CpuStorage::U8(result); Ok((result, l1.shape().clone())) } - CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or")), - CpuStorage::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "bitwise-or")), - CpuStorage::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "bitwise-or")), + CpuStorage::I16(vs1) => { + let vs2 = &s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::I16(result); + Ok((result, l1.shape().clone())) + } + CpuStorage::U32(vs1) => { + let vs2 = &s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::U32(result); + Ok((result, l1.shape().clone())) + } + CpuStorage::I64(vs1) => { + let vs2 = &s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::I64(result); + Ok((result, l1.shape().clone())) + } CpuStorage::I32(vs1) => { let vs2 = &s2.as_slice::().unwrap(); let result = self.bitwise(vs1, vs2); @@ -284,9 +299,21 @@ impl CustomOp1 for Leftshift { let result = CpuStorage::U8(result); Ok((result, l1.shape().clone())) } - CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshifr")), - CpuStorage::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshifr")), - CpuStorage::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshifr")), + CpuStorage::I16(vs1) => { + let result = self.leftshift(vs1); + let result = CpuStorage::I16(result); + Ok((result, l1.shape().clone())) + } + CpuStorage::U32(vs1) => { + let result = self.leftshift(vs1); + let result = CpuStorage::U32(result); + Ok((result, l1.shape().clone())) + } + CpuStorage::I64(vs1) => { + let result = self.leftshift(vs1); + let result = CpuStorage::I64(result); + Ok((result, l1.shape().clone())) + } CpuStorage::I32(vs1) => { let result = self.leftshift(vs1); let result = CpuStorage::I32(result); diff --git a/mistralrs/examples/llguidance/main.rs b/mistralrs/examples/llguidance/main.rs index 6a7ad08e28..b857d2c752 100644 --- a/mistralrs/examples/llguidance/main.rs +++ b/mistralrs/examples/llguidance/main.rs @@ -33,7 +33,6 @@ async fn main() -> Result<()> { .set_constraint(mistralrs::Constraint::Llguidance(LlguidanceGrammar { grammars: vec![top, schema], max_tokens: None, - test_trace: false, })) .set_sampler_max_len(100) .add_message(