Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
370 changes: 68 additions & 302 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] }
once_cell = "1.19.0"
# All features but avif, avif increases the msrv dramatically
image = { version = "0.25.1", default-features = false, features = ['bmp', 'dds', 'exr', 'ff', 'gif', 'hdr', 'ico', 'jpeg', 'png', 'pnm', 'qoi', 'tga', 'tiff', 'webp']}
reqwest = { version = "0.12.4", features = ["blocking"] }
reqwest = { version = "0.12.4", default-features = false, features = ["blocking", "rustls-tls", "charset", "http2", "macos-system-configuration"] }
base64 = "0.22.1"
half = "2.4.0"
rayon = "1.1.0"
Expand All @@ -52,3 +52,4 @@ regex = "1.10.6"
metal = { version = "0.27.0", features = ["mps"] }
safetensors = "0.4.5"
toml = "0.8.12"
hf-hub = { version = "0.4.1", default-features = false, features = ["ureq", "tokio", "rustls-tls"] }
8 changes: 4 additions & 4 deletions mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "496a8d2b", optional = true }
dirs = "5.0.1"
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
hf-hub.workspace = true
thiserror = "1.0.57"
tokenizers = "0.21.0"
tokenizers = { version = "0.21.0", default-features = false }
tqdm = "0.7.0"
chrono = "0.4.34"
minijinja = { version = "2.0.2", features = ["builtins", "json"] }
Expand Down Expand Up @@ -77,8 +77,8 @@ regex.workspace = true
serde_plain = "1.0.2"
as-any = "0.3.1"
float8.workspace = true
llguidance = { git = "https://github.com/microsoft/llguidance", rev = "cfef3df97372a7b84d74976ff41cc9cb78bca6cc", default-features = false, features = ["lark"] }
toktrie_hf_tokenizers = { git = "https://github.com/microsoft/llguidance", rev = "cfef3df97372a7b84d74976ff41cc9cb78bca6cc" }
llguidance = { git = "https://github.com/EricLBuehler/llguidance", rev = "8d71957", default-features = false, features = ["lark"] }
toktrie_hf_tokenizers = { git = "https://github.com/EricLBuehler/llguidance", rev = "8d71957" }
objc = { version = "0.2.7", optional = true }
metal = { workspace = true, optional = true }
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "496a8d2b", optional = true }
Expand Down
17 changes: 6 additions & 11 deletions mistralrs-core/src/pipeline/llg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ use std::sync::Arc;

use anyhow::Result;
use llguidance::{
api::{ParserLimits, RegexNode, TopLevelGrammar},
lark_to_llguidance,
api::{ParserLimits, TopLevelGrammar},
toktrie::{InferenceCapabilities, TokEnv},
JsonCompileOptions, TokenParser,
TokenParser,
};
use tokenizers::Tokenizer;

Expand All @@ -21,13 +20,9 @@ pub fn build_tok_env(tokenizer: Tokenizer) -> TokEnv {

pub fn llg_grammar_from_constraint(constraint: &Constraint) -> Result<Option<TopLevelGrammar>> {
let grm = match constraint {
Constraint::Regex(regex) => {
TopLevelGrammar::from_regex(RegexNode::Regex(regex.to_string()))
}
Constraint::Lark(lark) => lark_to_llguidance(lark)?,
Constraint::JsonSchema(value) => {
JsonCompileOptions::default().json_to_llg_no_validate(value.clone())?
}
Constraint::Regex(regex) => TopLevelGrammar::from_regex(regex),
Constraint::Lark(lark) => TopLevelGrammar::from_lark(lark.clone()),
Constraint::JsonSchema(value) => TopLevelGrammar::from_json_schema(value.clone()),
Constraint::Llguidance(value) => value.clone(),
Constraint::None => return Ok(None),
};
Expand All @@ -38,7 +33,7 @@ pub fn constraint_from_llg_grammar(
tok_env: TokEnv,
grm: TopLevelGrammar,
) -> Result<llguidance::Constraint> {
let parser = TokenParser::from_llguidance_json(
let parser = TokenParser::from_grammar(
tok_env,
grm,
llguidance::Logger::new(0, 1),
Expand Down
1 change: 1 addition & 0 deletions mistralrs-quant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ yoke = "0.7.5"
memmap2 = "0.9.5"
safetensors.workspace = true
regex.workspace = true
hf-hub.workspace = true

[features]
cuda = [
Expand Down
47 changes: 46 additions & 1 deletion mistralrs-quant/src/blockwise_fp8/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,14 @@ pub fn fp8_blockwise_dequantize(
}

#[cfg(test)]
#[allow(unused_imports)]
mod tests {
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};
use half::bf16;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

use crate::blockwise_fp8::ops;
use crate::{blockwise_fp8::ops, safetensors::MmapedSafetensors};

#[test]
fn test_fp8_blockwise_dequant() -> Result<()> {
Expand Down Expand Up @@ -455,4 +458,46 @@ mod tests {

Ok(())
}

#[cfg(feature = "cuda")]
#[test]
fn test_blockwise_fp8_gemm() -> Result<()> {
let dev = Device::cuda_if_available(0)?;

let api = ApiBuilder::new().with_progress(true).build().unwrap();
let api = api.repo(Repo::with_revision(
"EricB/mistralrs_tests".to_string(),
RepoType::Model,
"main".to_string(),
));

let filename = api.get("test_fp8.safetensors").unwrap();
let vb = unsafe { MmapedSafetensors::new(filename)? };

let weight = vb.load("weight", &dev, None)?;
assert_eq!((7168, 2048), weight.dims2()?);
assert_eq!(DType::F8E4M3, weight.dtype());

let scale = vb.load("scale", &dev, None)?;
assert_eq!((56, 16), scale.dims2()?);
assert_eq!(DType::F32, scale.dtype());

let weight_block_size = vec![128, 128];

// in dim is 2048.
let xs = Tensor::randn(0f32, 1f32, (32, 2048), &dev)?.to_dtype(DType::BF16)?;

let truth = {
let weight_dq =
ops::fp8_blockwise_dequantize(&weight, &scale, weight_block_size, DType::BF16)?;

let lin_dq = Linear::new(weight_dq, None);
lin_dq.forward(&xs)?
};

// TODO: will be adding real blockwise fp8 gemm shortly ;)
assert_eq!((32, 7168), truth.dims2()?);

Ok(())
}
}
1 change: 1 addition & 0 deletions mistralrs-quant/src/distributed/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ impl Server {
pub fn broadcast_id(&self, id: &Id) -> Result<()> {
let body = id.internal();
// SAFETY: We know the provenance and lifetime of `body` are valid.
#[allow(clippy::unnecessary_cast)]
let body_bytes = unsafe { slice::from_raw_parts(body.as_ptr() as *const u8, body.len()) };
for mut stream in &self.connections {
stream.write_all(body_bytes)?;
Expand Down
7 changes: 4 additions & 3 deletions mistralrs-quant/src/hqq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ impl HqqBits {
(10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize,
wq_in.dims()[1],
),
DType::I32,
DType::U32,
wq_in.device(),
)?;
let wq =
wq.slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::I32)?)?;
let wq = wq
.slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::U32)?)?
.to_dtype(DType::I32)?;
let step = (wq.dims()[0] as f64 / 10.) as usize;

let a = wq.narrow(0, 0, step)?;
Expand Down
39 changes: 33 additions & 6 deletions mistralrs-quant/src/utils/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,24 @@ impl CustomOp2 for BitWiseOr {
let result = CpuStorage::U8(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "bitwise-or")),
CpuStorage::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "bitwise-or")),
CpuStorage::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "bitwise-or")),
CpuStorage::I16(vs1) => {
let vs2 = &s2.as_slice::<i16>().unwrap();
let result = self.bitwise(vs1, vs2);
let result = CpuStorage::I16(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::U32(vs1) => {
let vs2 = &s2.as_slice::<u32>().unwrap();
let result = self.bitwise(vs1, vs2);
let result = CpuStorage::U32(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::I64(vs1) => {
let vs2 = &s2.as_slice::<i64>().unwrap();
let result = self.bitwise(vs1, vs2);
let result = CpuStorage::I64(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::I32(vs1) => {
let vs2 = &s2.as_slice::<i32>().unwrap();
let result = self.bitwise(vs1, vs2);
Expand Down Expand Up @@ -284,9 +299,21 @@ impl CustomOp1 for Leftshift {
let result = CpuStorage::U8(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "leftshifr")),
CpuStorage::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshifr")),
CpuStorage::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshifr")),
CpuStorage::I16(vs1) => {
let result = self.leftshift(vs1);
let result = CpuStorage::I16(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::U32(vs1) => {
let result = self.leftshift(vs1);
let result = CpuStorage::U32(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::I64(vs1) => {
let result = self.leftshift(vs1);
let result = CpuStorage::I64(result);
Ok((result, l1.shape().clone()))
}
CpuStorage::I32(vs1) => {
let result = self.leftshift(vs1);
let result = CpuStorage::I32(result);
Expand Down
1 change: 0 additions & 1 deletion mistralrs/examples/llguidance/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ async fn main() -> Result<()> {
.set_constraint(mistralrs::Constraint::Llguidance(LlguidanceGrammar {
grammars: vec![top, schema],
max_tokens: None,
test_trace: false,
}))
.set_sampler_max_len(100)
.add_message(
Expand Down
Loading