diff --git a/Cargo.lock b/Cargo.lock index f34ac6b251..bc3416cdf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2426,8 +2426,7 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "llguidance" version = "0.7.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b36091de5db3301cf2a2a16fec023b69dce306d2fdc9ba6a8a1fe9b404d6ad37" +source = "git+https://github.com/guidance-ai/llguidance.git?rev=2ce5ab8#2ce5ab8196f16dd8beba5a3d874eb1ab74e0268c" dependencies = [ "anyhow", "derivre", @@ -4830,8 +4829,7 @@ dependencies = [ [[package]] name = "toktrie" version = "0.7.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24abc046cdf691cae38efcc45a52e68e03ab26954157c6c15e2e7c9f6e46fef4" +source = "git+https://github.com/guidance-ai/llguidance.git?rev=2ce5ab8#2ce5ab8196f16dd8beba5a3d874eb1ab74e0268c" dependencies = [ "anyhow", "bytemuck", @@ -4843,8 +4841,7 @@ dependencies = [ [[package]] name = "toktrie_hf_tokenizers" version = "0.7.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200f5b166ecb572393f0fdacabc7904a6f5f9bc2766e784e4d2b0c7a80e2bed9" +source = "git+https://github.com/guidance-ai/llguidance.git?rev=2ce5ab8#2ce5ab8196f16dd8beba5a3d874eb1ab74e0268c" dependencies = [ "anyhow", "log", diff --git a/Cargo.toml b/Cargo.toml index 3831dd0362..e7611bc183 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -125,8 +125,8 @@ schemars = "0.8.22" serde_yaml = "0.9.34" serde_plain = "1.0.2" as-any = "0.3.2" -llguidance = { version = "0.7.29", default-features = false, features = ["lark"] } -toktrie_hf_tokenizers = "0.7.29" +llguidance = { git = "https://github.com/guidance-ai/llguidance.git", version = "0.7.29", default-features = false, features = ["lark"], rev = "2ce5ab8" } +toktrie_hf_tokenizers = {git = "https://github.com/guidance-ai/llguidance.git", version = "0.7.29", rev = "2ce5ab8" } objc = { version = "0.2.7" } serde-big-array = "0.5.1" interprocess = "2.2.3" diff --git a/README.md b/README.md index d35b7ad9d1..ba05ab204c 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,7 @@ If you do not specify the architecture, an attempt will be made to use the model - `phi3.5moe` - `qwen2` - `gemma2` +- `glm4` - `starcoder2` - `deepseekv2` - `deepseekv3` @@ -426,6 +427,7 @@ If you do not specify the architecture, an attempt will be made to use the model - phi3 - starcoder2 - qwen2 +- qwen3 **With adapters:** - llama @@ -456,6 +458,7 @@ Please submit more benchmarks via raising an issue! |Phi 3 Vision| | |✅| |Idefics 2| | |✅| |Gemma 2| | |✅| +|GLM4| | |✅| |Starcoder 2| |✅|✅| |LLaVa Next| | |✅| |LLaVa| | |✅| @@ -469,7 +472,7 @@ Please submit more benchmarks via raising an issue! |Gemma 3| | |✅| |Mistral 3| | |✅| |Llama 4| | |✅| -|Qwen 3| | |✅| +|Qwen 3|✅| |✅| |Dia 1.6b| | |✅| @@ -502,6 +505,7 @@ Please submit more benchmarks via raising an issue! |Phi 3 Vision| | | | |Idefics 2| | | | |Gemma 2|✅| | | +|GLM4|✅| | | |Starcoder 2|✅| | | |LLaVa Next| | | | |LLaVa| | | | diff --git a/docs/GLM4.md b/docs/GLM4.md new file mode 100644 index 0000000000..eb6aac7ae3 --- /dev/null +++ b/docs/GLM4.md @@ -0,0 +1,59 @@ +# GLM4 Model + +**[See the GLM4 model Collection](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e)** + +GLM4 is a series of open, multilingual, and multimodal large language models. The text-to-text LLM backbones in GLM4 are supported by mistral.rs. 
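+
+The HTTP API example below drives a running `mistralrs-server` instance through its OpenAI-compatible endpoint. It assumes an `openai` client has already been created and pointed at the server; a minimal sketch is shown here (the base URL, port, and API key are placeholders; adjust them to match however you launched the server):
+
+```py
+import openai
+
+# Placeholder values: point this at wherever your mistralrs-server is listening.
+client = openai.OpenAI(
+    api_key="foo",
+    base_url="http://localhost:1234/v1/",
+)
+```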
+ +## HTTP API + +```py +import openai + +messages = [] +prompt = input("Enter system prompt >>> ") +if len(prompt) > 0: + messages.append({"role": "system", "content": prompt}) + + +while True: + prompt = input(">>> ") + messages.append({"role": "user", "content": prompt}) + completion = client.chat.completions.create( + model="glm4", + messages=messages, + max_tokens=256, + frequency_penalty=1.0, + top_p=0.1, + temperature=0, + ) + resp = completion.choices[0].message.content + print(resp) + messages.append({"role": "assistant", "content": resp}) +``` + +## Python API +```py +from mistralrs import Runner, Which, ChatCompletionRequest, Architecture + +runner = Runner( + which=Which.Plain( + model_id="THUDM/GLM-4-9B-0414", + arch=Architecture.GLM4, + ), +) + +res = runner.send_chat_completion_request( + ChatCompletionRequest( + model="glm4", + messages=[ + {"role": "user", "content": "Tell me a story about the Rust type system."} + ], + max_tokens=256, + presence_penalty=1.0, + top_p=0.1, + temperature=0.1, + ) +) +print(res.choices[0].message.content) +print(res.usage) +``` \ No newline at end of file diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 0fc0865979..8754b7b062 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -2127,6 +2127,42 @@ impl Mlp { }) } + pub fn new_merged( + vb: ShardedVarBuilder, + hidden_size: usize, + intermediate_size: usize, + chunks: usize, + quantization_config: &Option, + hidden_act: Activation, + comm: &Arc, + ) -> Result { + assert!(chunks == 2, "Only gate_up_proj merge is supported!"); + let gate_up_projs = ColumnParallelLayer::new_merged( + hidden_size, + intermediate_size * 2, + 2, + quantization_config, + false, + comm, + vb.pp("gate_up_proj"), + )?; + + Ok(Self { + gate: gate_up_projs[0].to_owned(), + up: gate_up_projs[1].to_owned(), + down: RowParallelLayer::new( + intermediate_size, + hidden_size, + quantization_config, + false, + comm, + vb.pp("down_proj"), + )?, + act: hidden_act, + params: vec![hidden_size, intermediate_size], + }) + } + pub fn replicate( params: &[usize], vb: ShardedVarBuilder, diff --git a/mistralrs-core/src/models/glm4.rs b/mistralrs-core/src/models/glm4.rs new file mode 100644 index 0000000000..fab2f5b52c --- /dev/null +++ b/mistralrs-core/src/models/glm4.rs @@ -0,0 +1,907 @@ +#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] + +use crate::{ + amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp}, + attention::SdpaParams, + device_map::DeviceMapper, + get_delta_from_lora_ab, + layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, Sdpa}, + layers_masker::PastKvLenCache, + paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, + pipeline::{ + extract_logits, + text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, + EitherCache, IsqModel, KvCache, NormalCache, NormalCacheType, NormalLoadingMetadata, + NormalModel, + }, + serde_default_fn, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, +}; +use candle_core::IndexOp; +use candle_core::{DType, Device, Module, Result, Tensor, D}; +use mistralrs_quant::{ + ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer, + ShardedVarBuilder, +}; +use serde::{Deserialize, Serialize}; +use std::iter::zip; +use std::{collections::HashMap, sync::Arc}; + +serde_default_fn!(bool, tie_word_embeddings, false); +serde_default_fn!(usize, max_position_embeddings, 32768); + +#[derive(Debug, 
Clone, Default, Serialize, Deserialize)] +pub struct Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) hidden_act: Activation, + pub(crate) rms_norm_eps: f64, + pub(crate) rope_theta: f64, + pub(crate) sliding_window: Option, + pub(crate) partial_rotary_factor: Option, + #[serde(default = "max_position_embeddings")] + pub(crate) max_position_embeddings: usize, + pub(crate) attention_bias: Option, + pub(crate) head_dim: Option, + pub(crate) quantization_config: Option, + #[serde(default = "tie_word_embeddings")] + pub(crate) tie_word_embeddings: bool, +} + +impl Config { + pub(crate) fn head_dim(&self) -> usize { + self.head_dim + .unwrap_or(self.hidden_size / self.num_attention_heads) + } +} + +struct RotaryEmbedding { + cos: Tensor, + sin: Tensor, + rotary_dim: usize, +} + +impl RotaryEmbedding { + fn new( + rope_theta: f32, + partial_rotary_factor: Option, + head_dim: usize, + max_seq_len: usize, + dev: &Device, + dtype: DType, + ) -> Result { + let mut rotary_dim = head_dim; + if let Some(factor) = partial_rotary_factor { + rotary_dim = (factor * head_dim as f32) as usize; + }; + + let inv_freq: Vec<_> = (0..rotary_dim) + .step_by(2) + .map(|i| 1f32 / rope_theta.powf(i as f32 / rotary_dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, + rotary_dim, + }) + } + + fn apply_rotary_emb(&self, xs: &Tensor, input_positions: &[usize]) -> Result { + let (b_size, _num_heads, seq_len, _headdim) = xs.dims4()?; + let mut embeds = Vec::new(); + for (b, seqlen_offset) in zip(0..b_size, input_positions) { + let (s, e) = (*seqlen_offset, *seqlen_offset + seq_len); + let cos = self.cos.i((s..e, ..))?.contiguous()?; + let sin = self.sin.i((s..e, ..))?.contiguous()?; + let xs_rot = xs + .i((b, .., .., ..self.rotary_dim))? + .unsqueeze(0)? 
+ .contiguous()?; + let xs_pass = xs.i((b, .., .., self.rotary_dim..))?.unsqueeze(0)?; + let xs_rot = candle_nn::rotary_emb::rope_i(&xs_rot, &cos, &sin).unwrap(); + let embed = Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?; + embeds.push(embed); + } + Tensor::cat(&embeds, 0) + } +} + +struct Attention { + q_proj: Arc, + k_proj: Arc, + v_proj: Arc, + o_proj: Arc, + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + rotary_emb: Arc, + paged_attn: Option, + sdpa_params: SdpaParams, +} + +impl Attention { + #[allow(clippy::too_many_arguments)] + fn new( + rotary_emb: Arc, + cfg: &Config, + vb: ShardedVarBuilder, + mapper: &dyn DeviceMapper, + layer_idx: usize, + loading_isq: bool, + paged_attn: Option, + comm: &Arc, + ) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let head_dim = cfg.head_dim(); + let q_proj = ColumnParallelLayer::new( + hidden_sz, + num_heads * head_dim, + &cfg.quantization_config, + cfg.attention_bias.unwrap_or(false), + comm, + mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq), + )?; + let kv_shard = mistralrs_quant::compute_kv_shard( + cfg.num_key_value_heads, + cfg.hidden_size / cfg.num_attention_heads, + comm, + ); + let k_proj = ColumnParallelLayer::new_with_shard( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + cfg.attention_bias.unwrap_or(false), + comm, + kv_shard, + mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq), + )?; + let v_proj = ColumnParallelLayer::new_with_shard( + hidden_sz, + num_kv_heads * head_dim, + &cfg.quantization_config, + cfg.attention_bias.unwrap_or(false), + comm, + kv_shard, + mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq), + )?; + let o_proj = RowParallelLayer::new( + num_heads * head_dim, + hidden_sz, + &cfg.quantization_config, + false, + comm, + mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq), + )?; + + assert!(cfg.num_attention_heads >= comm.world_size()); + assert!(cfg.num_attention_heads % comm.world_size() == 0); + + assert!(cfg.num_key_value_heads >= comm.world_size()); + assert!(cfg.num_key_value_heads % comm.world_size() == 0); + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads: num_heads / comm.world_size(), + num_kv_heads: (num_kv_heads / comm.world_size()).max(1), + head_dim, + rotary_emb, + paged_attn, + sdpa_params: SdpaParams { + n_kv_groups: mistralrs_quant::compute_n_kv_groups( + cfg.num_key_value_heads, + cfg.num_attention_heads, + comm, + ), + softcap: None, + softmax_scale: 1.0 / (head_dim as f32).sqrt(), + sliding_window: cfg.sliding_window, + }, + }) + } + + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offsets: &[usize], + kv_cache: &mut KvCache, + metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let original_dtype = xs.dtype(); + let mut xs = xs.clone(); + if let Some(t) = self.q_proj.quantized_act_type() { + xs = xs.to_dtype(t)?; + } + let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?; + let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?; + let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?; + + (q, k, v) = if q_len != 1 { + let q = q + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? 
+ .transpose(1, 2)?; + let v = v + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + (q, k, v) + } else { + let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?; + let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?; + let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?; + (q, k, v) + }; + + q = self.rotary_emb.apply_rotary_emb(&q, seqlen_offsets)?; + k = self.rotary_emb.apply_rotary_emb(&k, seqlen_offsets)?; + + if self.q_proj.quantized_act_type().is_some() { + q = q.to_dtype(original_dtype)?; + k = k.to_dtype(original_dtype)?; + v = v.to_dtype(original_dtype)?; + } + + let mut attn_output = match &self.paged_attn { + Some(paged_attn) => match metadata { + Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward( + &q, + &k, + &v, + attention_mask, + Some(key_cache), + Some(value_cache), + input_metadata, + &self.sdpa_params, + Some(flash_params), + )?, + None => { + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); + paged_attn.forward( + &q, + &k, + &v, + attention_mask, + None, + None, + &input_metadata, + &self.sdpa_params, + Some(flash_params), + )? + } + }, + None => { + let (k, v) = kv_cache.append(&k, &v)?; + + Sdpa.run_attention( + &q, + &k, + &v, + attention_mask, + Some(flash_params), + &self.sdpa_params, + )? + } + }; + + if let Some(t) = self.q_proj.quantized_act_type() { + attn_output = attn_output.to_dtype(t)?; + } + attn_output = if attention_mask.is_some() { + attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? + } else { + attn_output.reshape((b_sz, q_len, ()))? 
+ }; + let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?; + if self.q_proj.quantized_act_type().is_some() { + res = res.to_dtype(original_dtype)?; + } + Ok(res) + } +} + +struct DecoderLayer { + self_attn: Attention, + mlp: Box, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, + post_mlp_layernorm: RmsNorm, + post_self_attn_layernorm: RmsNorm, +} + +impl DecoderLayer { + #[allow(clippy::too_many_arguments)] + fn new( + rotary_emb: Arc, + cfg: &Config, + vb: ShardedVarBuilder, + mapper: &dyn DeviceMapper, + layer_idx: usize, + loading_isq: bool, + paged_attn: Option, + comm: &Arc, + ) -> Result { + let self_attn = Attention::new( + rotary_emb, + cfg, + mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq), + mapper, + layer_idx, + loading_isq, + paged_attn, + comm, + )?; + let mlp = Mlp::new_merged( + mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq), + cfg.hidden_size, + cfg.intermediate_size, + 2, + &cfg.quantization_config, + cfg.hidden_act, + comm, + )?; + let input_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_device(layer_idx, vb.pp("input_layernorm"), false), + )?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false), + )?; + + let post_self_attn_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_device(layer_idx, vb.pp("post_self_attn_layernorm"), false), + )?; + let post_mlp_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_device(layer_idx, vb.pp("post_mlp_layernorm"), false), + )?; + + Ok(Self { + self_attn, + mlp: Box::new(mlp), + input_layernorm, + post_attention_layernorm, + post_self_attn_layernorm, + post_mlp_layernorm, + }) + } + + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offsets: &[usize], + kv_cache: &mut KvCache, + metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + let residual = xs; + let hidden_states = self.input_layernorm.forward(xs)?; + let hidden_states = self.self_attn.forward( + &hidden_states, + attention_mask, + seqlen_offsets, + kv_cache, + metadata, + flash_params, + )?; + let hidden_states = self.post_self_attn_layernorm.forward(&hidden_states)?; + let hidden_states = (residual + hidden_states)?; + let residual = &hidden_states; + let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; + let hidden_states = self.mlp.forward(&hidden_states)?; + let hidden_states = self.post_mlp_layernorm.forward(&hidden_states)?; + residual + hidden_states + } +} + +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: Arc, + sliding_window: Option, + device: Device, + cache: EitherCache, + max_seq_len: usize, + mapper: Box, + cfg: ModelConfigMetadata, +} + +impl Model { + pub fn new( + cfg: &Config, + vb: ShardedVarBuilder, + is_gptx: bool, + normal_loading_metadata: NormalLoadingMetadata, + attention_mechanism: AttentionImplementation, + ) -> Result { + let vb_m = vb.pp("model"); + let vb_lm_head = vb.pp("lm_head"); + Self::new_inner( + cfg, + vb_m, + vb_lm_head, + is_gptx, + normal_loading_metadata, + attention_mechanism, + ) + } + + pub fn new_inner( + cfg: &Config, + vb_m: ShardedVarBuilder, + vb_lm_head: ShardedVarBuilder, + _is_gptx: bool, + normal_loading_metadata: NormalLoadingMetadata, + attention_mechanism: 
AttentionImplementation, + ) -> Result { + if let Some(ref quant_cfg) = &cfg.quantization_config { + tracing::info!( + "Using {} quantization: {}.", + quant_cfg.name(), + quant_cfg.get_bits_name(&vb_m) + ); + } + let mapper = normal_loading_metadata.mapper; + + let embed_tokens = embedding( + cfg.vocab_size, + cfg.hidden_size, + mapper.set_nm_device(vb_m.pp("embed_tokens"), false), + &cfg.quantization_config, + )?; + + let head_dim = cfg.head_dim(); + let mut ropes = HashMap::new(); + for layer_idx in 0..cfg.num_hidden_layers { + let device = mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device); + ropes.insert( + device.location(), + Arc::new(RotaryEmbedding::new( + cfg.rope_theta as f32, + cfg.partial_rotary_factor, + head_dim, + cfg.max_position_embeddings, + device, + if normal_loading_metadata.loading_isq { + DType::F32 + } else { + vb_m.dtype() + }, + )?), + ); + } + + let vb_l = vb_m.pp("layers"); + let layers = NiceProgressBar::<_, 'b'>( + 0..cfg.num_hidden_layers, + "Loading repeating layers", + &normal_loading_metadata.multi_progress, + ) + .par_iter_if_isq(|layer_idx| -> Result { + let device = mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device); + let rotary_emb = ropes + .get(&device.location()) + .expect("No RoPE for device location!") + .clone(); + let paged_attn = match &attention_mechanism { + AttentionImplementation::Eager => None, + AttentionImplementation::PagedAttention => { + Some(PagedAttention::new(head_dim, device, None)?) + } + }; + let comm = mapper.get_comm_for(layer_idx)?; + DecoderLayer::new( + rotary_emb.clone(), + cfg, + vb_l.pp(layer_idx), + &*mapper, + layer_idx, + normal_loading_metadata.loading_isq, + paged_attn, + &comm, + ) + })?; + let norm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + mapper.set_nm_device(vb_m.pp("norm"), false), + )?; + let lm_head = if !cfg.tie_word_embeddings { + ReplicatedLayer::new( + cfg.hidden_size, + cfg.vocab_size, + &cfg.quantization_config, + false, + mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq), + )? + } else { + ReplicatedLayer::from_linear(candle_nn::Linear::new( + mapper.cast_nm_device( + embed_tokens.embeddings(), + normal_loading_metadata.loading_isq, + )?, + None, + ))? 
+ }; + let cache_types = (0..cfg.num_hidden_layers) + .map(|_| { + cfg.sliding_window + .map(|window| NormalCacheType::SlidingWindow { window }) + .unwrap_or(NormalCacheType::Normal { + max_seq_len: cfg.max_position_embeddings, + }) + }) + .collect::>(); + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.sliding_window, + device: normal_loading_metadata.real_device, + cache: EitherCache::Normal(NormalCache::from_types(cache_types)), + max_seq_len: cfg.max_position_embeddings, + cfg: ModelConfigMetadata { + max_seq_len: cfg.max_position_embeddings, + num_layers: cfg.num_hidden_layers, + hidden_size: cfg.hidden_size, + num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size()) + .max(1), + num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(), + sliding_window: cfg.sliding_window, + k_head_dim: cfg.head_dim(), + v_head_dim: cfg.head_dim(), + }, + mapper, + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + seqlen_offsets: &[usize], + context_lens: Vec<(usize, usize)>, + metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + self.forward_embeds( + input_ids, + self.embed_tokens.forward(input_ids)?, + seqlen_offsets, + context_lens, + metadata, + flash_params, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn forward_embeds( + &self, + input_ids: &Tensor, + input_embeds: Tensor, + seqlen_offsets: &[usize], + context_lens: Vec<(usize, usize)>, + metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + let mut xs = input_embeds; + let cache = &mut self.cache.normal().0; + let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix( + input_ids, + metadata + .as_ref() + .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache) + .unwrap_or(cache as &dyn PastKvLenCache), + self.sliding_window, + xs.dtype(), + self.cfg.num_attn_heads, + )?; + // PagedAttention prompt chunking + let attention_mask = attention_mask.filter(|_| { + metadata + .as_ref() + .map(|(_, meta)| meta.is_first_prompt_chunk) + .unwrap_or(true) + }); + for (i, layer) in self.layers.iter().enumerate() { + xs = self.mapper.map(xs, i)?; + xs = layer.forward( + &xs, + attention_mask + .as_ref() + .map(|m| m.to_device(xs.device()).unwrap()) + .as_ref(), + seqlen_offsets, + &mut cache[i], + metadata + .as_ref() + .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)), + flash_params, + )?; + } + let xs = xs.to_device(&self.device)?; + let mut xs = xs.apply(&self.norm)?; + if let Some(t) = self.lm_head.quantized_act_type() { + xs = xs.to_dtype(t)?; + } + extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens) + } +} + +impl IsqModel for Model { + fn get_layers( + &mut self, + ) -> ( + Vec<(&mut Arc, Option)>, + &dyn DeviceMapper, + ) { + let mut tensors = Vec::new(); + tensors.push((&mut self.lm_head, None)); + for (i, layer) in self.layers.iter_mut().enumerate() { + tensors.push((&mut layer.self_attn.q_proj, Some(i))); + tensors.push((&mut layer.self_attn.k_proj, Some(i))); + tensors.push((&mut layer.self_attn.v_proj, Some(i))); + tensors.push((&mut layer.self_attn.o_proj, Some(i))); + tensors.extend( + layer + .mlp + .get_isq_layers() + .into_iter() + .map(|m| (m, Some(i))) + .collect::>(), + ); + } + (tensors, &*self.mapper) + } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + 
uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + uvb_l + .pp("post_self_attn_layernorm") + .add(&layer.post_self_attn_layernorm); + uvb_l + .pp("post_mlp_layernorm") + .add(&layer.post_mlp_layernorm); + } + + uvb.to_safetensors() + } + + fn imatrix_names(&self) -> candle_core::Result>> { + // NOTE: dependant on the exact implementation in get_layers! + let mut names = Vec::new(); + // lm_head + names.push(None); + for i in 0..self.layers.len() { + names.push(Some(format!("blk.{i}.attn_q.weight"))); + names.push(Some(format!("blk.{i}.attn_k.weight"))); + names.push(Some(format!("blk.{i}.attn_v.weight"))); + names.push(Some(format!("blk.{i}.attn_output.weight"))); + names.push(Some(format!("blk.{i}.ffn_gate.weight"))); + names.push(Some(format!("blk.{i}.ffn_up.weight"))); + names.push(Some(format!("blk.{i}.ffn_down.weight"))); + } + Ok(names) + } +} + +impl NormalModel for Model { + fn forward( + &self, + input_ids: &Tensor, + seqlen_offsets: &[usize], + context_lens: Vec<(usize, usize)>, + _position_ids: Vec, + metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>, + flash_params: &FlashParams, + ) -> Result { + self.forward( + input_ids, + seqlen_offsets, + context_lens, + metadata, + flash_params, + ) + } + fn xlora_forward( + &self, + _input_ids: &Tensor, + _input_ids_full: &Tensor, + _seqlen_offsets: &[usize], + _seqlen_offsets_full: &[usize], + _no_kv_cache: bool, + _non_granular_state: &Option, + _context_lens: Vec<(usize, usize)>, + _position_ids: Vec, + _flash_params: &FlashParams, + _flash_params_full: &FlashParams, + ) -> Result { + unimplemented!() + } + fn cache(&self) -> &EitherCache { + &self.cache + } + fn cache_mut(&mut self) -> &mut EitherCache { + &mut self.cache + } + fn device(&self) -> &Device { + &self.device + } + fn is_xlora(&self) -> bool { + false + } + fn max_seq_len(&self) -> usize { + self.max_seq_len + } + fn config(&self) -> &ModelConfigMetadata { + &self.cfg + } +} + +impl AnyMoeBaseModelMixin for Model { + fn get_mlps(&self) -> Vec<&dyn MlpLayer> { + let mut mlps = Vec::new(); + for layer in &self.layers { + mlps.push(&*layer.mlp); + } + mlps + } + fn get_mlps_mut(&mut self) -> Vec<&mut Box> { + let mut mlps = Vec::new(); + for layer in &mut self.layers { + mlps.push(&mut layer.mlp); + } + mlps + } + fn create_anymoe_layers( + &mut self, + additional_vbs: Vec, + config: AnyMoeConfig, + (prefix, mlp): (String, String), + mut layers: Vec, + expert_type: AnyMoeExpertType, + gate_vb: Option, + ) -> Result<()> { + let mut experts: Vec>> = Vec::new(); + if layers.is_empty() { + layers = (0..self.layers.len()).collect::>(); + } + for _ in 0..layers.len() { + experts.push(Vec::new()); + } + for vb in additional_vbs { + let vb = vb.pp(&prefix); + for (layer, row) in experts.iter_mut().enumerate() { + if !layers.contains(&layer) { + continue; + } + + let intermediate_size = self.layers[layer].mlp.get_params()[1]; + let hidden_size = self.layers[layer].mlp.get_params()[0]; + match expert_type { + AnyMoeExpertType::FineTuned => { + let (dtype, device) = self.layers[layer].mlp.dtype_device(); + row.push(Box::new(Mlp::replicate( + self.layers[layer].mlp.get_params(), + vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device), + self.layers[layer].mlp.hidden_act(), + 
&self.mapper.get_comm_for(layer)?, + )?)); + } + AnyMoeExpertType::LoraAdapter { + rank, + alpha, + ref target_modules, + } => { + let vb_mlp = vb.pp(layer).pp(&mlp); + + let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) { + Some(get_delta_from_lora_ab!( + vb_mlp, + rank, + alpha, + (hidden_size, intermediate_size), + "gate_proj" + )) + } else { + None + }; + let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) { + Some(get_delta_from_lora_ab!( + vb_mlp, + rank, + alpha, + (hidden_size, intermediate_size), + "up_proj" + )) + } else { + None + }; + let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) { + Some(get_delta_from_lora_ab!( + vb_mlp, + rank, + alpha, + (intermediate_size, hidden_size), + "down_proj" + )) + } else { + None + }; + + row.push(self.layers[layer].mlp.new_added_delta(vec![ + gate_proj_delta, + up_proj_delta, + down_proj_delta, + ])?); + } + } + } + } + for (layer, expert) in layers.into_iter().zip(experts) { + let mut experts_all = vec![self.layers[layer].mlp.clone()]; + experts_all.extend(expert); + let (dtype, device) = self.layers[layer].mlp.dtype_device(); + self.layers[layer].mlp = Box::new(MoeMlp::new( + experts_all, + config.clone(), + dtype, + &device, + layer, + gate_vb.as_ref(), + )?); + } + Ok(()) + } + fn amoe_supported(&self) -> bool { + true + } +} diff --git a/mistralrs-core/src/models/mod.rs b/mistralrs-core/src/models/mod.rs index d3eefb3c02..5d63e82c0b 100644 --- a/mistralrs-core/src/models/mod.rs +++ b/mistralrs-core/src/models/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod deepseek2; pub(crate) mod deepseek3; pub(crate) mod gemma; pub(crate) mod gemma2; +pub(crate) mod glm4; pub(crate) mod llama; pub(crate) mod mistral; pub(crate) mod mixtral; diff --git a/mistralrs-core/src/pipeline/chat_template.rs b/mistralrs-core/src/pipeline/chat_template.rs index f3c9bada0d..904e9e4e66 100644 --- a/mistralrs-core/src/pipeline/chat_template.rs +++ b/mistralrs-core/src/pipeline/chat_template.rs @@ -292,7 +292,13 @@ pub fn apply_chat_template_to( template } }; - let template = template.replace("[::-1]", "|reverse"); + let mut template = template.replace("[::-1]", "|reverse"); + + if template.contains("{{ meta }}") { + //fix for GLM4 models + template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", ""); + template = template.replace("{{ meta }}", ""); + } env.add_template("chat_template", &template)?; env.add_function("raise_exception", raise_exception); diff --git a/mistralrs-core/src/pipeline/isq.rs b/mistralrs-core/src/pipeline/isq.rs index c12d884360..76f574ae75 100644 --- a/mistralrs-core/src/pipeline/isq.rs +++ b/mistralrs-core/src/pipeline/isq.rs @@ -29,7 +29,7 @@ impl<'a> CowBytesView<'a> { } } -impl<'a> safetensors::tensor::View for CowBytesView<'a> { +impl safetensors::tensor::View for CowBytesView<'_> { fn dtype(&self) -> safetensors::tensor::Dtype { // Serialize as raw bytes safetensors::tensor::Dtype::U8 diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index 3993c513c7..ffdf8b634c 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -20,10 +20,10 @@ use serde::Deserialize; use tokio::sync::Mutex; pub use normal_loaders::{ - AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader, - MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel, - NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, 
Qwen2Loader, Qwen3Loader, - Qwen3MoELoader, Starcoder2Loader, + AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader, + LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, + NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, + Qwen3Loader, Qwen3MoELoader, Starcoder2Loader, }; pub use vision_loaders::{ diff --git a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs index a8fd695393..4592a8f721 100644 --- a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs @@ -165,6 +165,8 @@ pub enum NormalLoaderType { DeepSeekV3, #[serde(rename = "qwen3")] Qwen3, + #[serde(rename = "glm4")] + GLM4, #[serde(rename = "qwen3moe")] Qwen3Moe, } @@ -186,6 +188,7 @@ impl NormalLoaderType { "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2), "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3), "Qwen3ForCausalLM" => Ok(Self::Qwen3), + "Glm4ForCausalLM" => Ok(Self::GLM4), "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe), other => anyhow::bail!( "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue." @@ -210,9 +213,10 @@ impl FromStr for NormalLoaderType { "phi3.5moe" => Ok(Self::Phi3_5MoE), "deepseekv2" => Ok(Self::DeepSeekV2), "deepseekv3" => Ok(Self::DeepSeekV3), - "qwen3" => Ok(Self::DeepSeekV3), + "qwen3" => Ok(Self::Qwen3), + "glm4" => Ok(Self::GLM4), "qwen3moe" => Ok(Self::Qwen3Moe), - a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `qwen3moe`.")), + a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`.")), } } } @@ -233,6 +237,7 @@ impl Display for NormalLoaderType { Self::DeepSeekV2 => write!(f, "deepseekv2"), Self::DeepSeekV3 => write!(f, "deepseekv3"), Self::Qwen3 => write!(f, "qwen3"), + Self::GLM4 => write!(f, "glm4"), Self::Qwen3Moe => write!(f, "qwen3moe"), } } @@ -283,6 +288,7 @@ impl AutoNormalLoader { NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)), NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)), NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)), + NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)), NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)), } } @@ -3138,6 +3144,183 @@ impl DeviceMappedModelLoader for Qwen3Loader { } } +/// [`NormalLoader`] for a GLM 4 model. 
+/// +/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html +pub struct GLM4Loader; + +impl NormalModelLoader for GLM4Loader { + fn load( + &self, + config: &str, + vb: ShardedVarBuilder, + normal_loading_metadata: NormalLoadingMetadata, + attention_mechanism: AttentionImplementation, + ) -> Result> { + let cfg: crate::models::glm4::Config = serde_json::from_str(config)?; + + Ok(Box::new(models::glm4::Model::new( + &cfg, + vb, + self.is_gptx(config)?, + normal_loading_metadata, + attention_mechanism, + )?)) + } + fn load_xlora( + &self, + _config: &str, + _vb: ShardedVarBuilder, + _lora_config: &[((String, String), LoraConfig)], + _xlora_config: Option, + _xlora_ordering: Ordering, + _normal_loading_metadata: NormalLoadingMetadata, + _preload_adapters: &Option>, + ) -> Result> { + todo!() + } + fn is_gptx(&self, _: &str) -> Result { + Ok(true) + } + fn get_config_repr(&self, config: &str) -> Result> { + let cfg: crate::models::glm4::Config = serde_json::from_str(config)?; + + Ok(Box::new(cfg)) + } +} + +impl IsqModelLoader for GLM4Loader { + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) + } + fn immediate_isq_predicates(&self, config: &str) -> Result> { + self.isq_layer_regexes(config) + } +} + +impl DeviceMappedModelLoader for GLM4Loader { + fn mapped_max_act_size_elems( + &self, + config: &str, + params: &AutoDeviceMapParams, + prompt_chunksize: usize, + ) -> Result { + let AutoDeviceMapParams::Text { + max_seq_len: _, + max_batch_size, + } = params + else { + anyhow::bail!("Expected text AutoDeviceMapParams for this model!") + }; + + let cfg: models::glm4::Config = serde_json::from_str(config)?; + + Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize) + } + fn non_mapped_max_act_size_elems( + &self, + _config: &str, + _params: &AutoDeviceMapParams, + ) -> Result { + Ok(0) + } + + fn non_mapped_size_in_bytes( + &self, + config: &str, + dtype: DType, + weight_pack_factor: usize, + ) -> Result { + let cfg: models::glm4::Config = serde_json::from_str(config)?; + let elems = { + let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor; + // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed + let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 { + cfg.hidden_size * cfg.vocab_size / weight_pack_factor + } else { + 0 + }; + let norm = cfg.hidden_size; + embed_tokens + lm_head + norm + }; + Ok(elems * dtype.size_in_bytes()) + } + + fn layer_sizes_in_bytes( + &self, + config: &str, + dtype: DType, + weight_pack_factor: usize, + ) -> Result> { + let cfg: models::glm4::Config = serde_json::from_str(config)?; + let per_layer_elems = { + let input_layernorm = cfg.hidden_size; + let post_attention_layernorm = cfg.hidden_size * 3; //+post_self_attn_layernorm and post_mlp_layernorm + + let size_in = cfg.hidden_size; + let size_q = cfg.head_dim() * cfg.num_attention_heads; + let size_kv = cfg.head_dim() * 
cfg.num_key_value_heads; + let q_proj = size_in * size_q / weight_pack_factor + size_q; + let k_proj = size_in * size_kv / weight_pack_factor + size_kv; + let v_proj = size_in * size_kv / weight_pack_factor + size_kv; + let o_proj = size_q * size_in / weight_pack_factor; + + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let gate_proj = h_size * i_size / weight_pack_factor; + let up_proj = h_size * i_size / weight_pack_factor; + let down_proj = i_size * h_size / weight_pack_factor; + + input_layernorm + + post_attention_layernorm + + q_proj + + k_proj + + v_proj + + o_proj + + gate_proj + + up_proj + + down_proj + }; + Ok(vec![ + per_layer_elems * dtype.size_in_bytes(); + cfg.num_hidden_layers + ]) + } + + fn num_layers(&self, config: &str) -> Result { + let cfg: models::glm4::Config = serde_json::from_str(config)?; + Ok(cfg.num_hidden_layers) + } + + fn model_config(&self, config: &str) -> Result> { + let cfg: models::glm4::Config = serde_json::from_str(config)?; + + let cfg = ModelConfigMetadata { + max_seq_len: cfg.max_position_embeddings, + num_layers: cfg.num_hidden_layers, + hidden_size: cfg.hidden_size, + num_kv_heads: cfg.num_key_value_heads, + num_attn_heads: cfg.num_attention_heads, + sliding_window: cfg.sliding_window, + k_head_dim: cfg.hidden_size / cfg.num_attention_heads, + v_head_dim: cfg.hidden_size / cfg.num_attention_heads, + }; + + Ok(Box::new(cfg)) + } +} + /// [`NormalLoader`] for a Qwen 3 MoE model. /// /// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index a69093a362..38b958860b 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -37,13 +37,14 @@ use llguidance::toktrie::TokEnv; pub use loaders::{ AdapterKind, AutoDeviceMapParams, AutoNormalLoader, AutoVisionLoader, DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel, - DiffusionModelLoader, FluxLoader, Gemma2Loader, Gemma3Loader, GemmaLoader, Idefics2Loader, - Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, - MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind, ModelPaths, - NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader, Phi2Loader, - Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName, QuantizationKind, - Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader, - TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader, + DiffusionModelLoader, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader, GemmaLoader, + Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, + LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind, + ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader, + Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName, + QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader, + Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel, + VisionModelLoader, }; use mistralrs_quant::IsqType; pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig}; diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 1a29421c14..e4f4870bb3 100644 --- 
a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -11,9 +11,9 @@ use super::{ IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin, }; use super::{ - AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader, - MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader, - Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader, + AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader, + LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, + Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader, }; use crate::amoe::AnyMoeExpertType; use crate::device_map::{self, DeviceMapper}; @@ -222,6 +222,7 @@ impl NormalLoaderBuilder { Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader), Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader), Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader), + Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader), Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader), None => Box::new(AutoNormalLoader), }; @@ -650,9 +651,14 @@ impl Loader for NormalLoader { }; let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?; - let gen_conf: Option = paths.get_gen_conf_filename().map(|f| { - serde_json::from_str(&fs::read_to_string(f).unwrap()) - .expect("bos_token_id/eos_token_id missing in generation_config.json") + let gen_conf: Option = paths.get_gen_conf_filename().and_then(|f| { + match serde_json::from_str::(&fs::read_to_string(f).unwrap()) { + Ok(conf) => Some(conf), + Err(e) => { + warn!("Failed to parse generation_config.json: {}", e); + None + } + } }); let chat_template_explicit = paths diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md index f584124215..4b6f1c368a 100644 --- a/mistralrs-pyo3/API.md +++ b/mistralrs-pyo3/API.md @@ -22,6 +22,7 @@ If you do not specify the architecture, an attempt will be made to use the model - `Phi3` - `Qwen2` - `Gemma2` +- `GLM4` - `Starcoder2` - `Phi3_5MoE` - `DeepseekV2` diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index 778f7e0080..9c56b31c27 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -104,6 +104,7 @@ class Architecture(Enum): DeepseekV2 = "deepseekv2" DeepseekV3 = "deepseekv3" Qwen3 = "qwen3" + GLM4 = "glm4" Qwen3Moe = "qwen3moe" @dataclass diff --git a/mistralrs-pyo3/src/which.rs b/mistralrs-pyo3/src/which.rs index a878aeef7b..43c090714f 100644 --- a/mistralrs-pyo3/src/which.rs +++ b/mistralrs-pyo3/src/which.rs @@ -22,6 +22,7 @@ pub enum Architecture { DeepseekV2, DeepseekV3, Qwen3, + GLM4, Qwen3Moe, } @@ -41,6 +42,7 @@ impl From for NormalLoaderType { Architecture::DeepseekV2 => Self::DeepSeekV2, Architecture::DeepseekV3 => Self::DeepSeekV3, Architecture::Qwen3 => Self::Qwen3, + Architecture::GLM4 => Self::GLM4, Architecture::Qwen3Moe => Self::Qwen3Moe, } } diff --git a/mistralrs-quant/README.md b/mistralrs-quant/README.md index 4e39e5763d..52223fb86f 100644 --- a/mistralrs-quant/README.md +++ b/mistralrs-quant/README.md @@ -11,7 +11,7 @@ It has grown beyon simply quantization and is used by `mistral.rs` to power: Currently supported: - GGUF: `GgufMatMul`(2-8 bit quantization, with imatrix) -- Gptq: `GptqLayer`(with CUDA marlin kernel) +- Gptq/Awq: `GptqAwqLayer`(with CUDA marlin kernel) - Hqq: `HqqLayer` (4, 8 bit quantization) - FP8: `FP8Linear`(optimized on CUDA) - Unquantized (used for ISQ): 
`UnquantLinear`

diff --git a/mistralrs-quant/src/distributed/layers.rs b/mistralrs-quant/src/distributed/layers.rs
index 33c5ec98a7..d3f91275e9 100644
--- a/mistralrs-quant/src/distributed/layers.rs
+++ b/mistralrs-quant/src/distributed/layers.rs
@@ -338,6 +338,35 @@ impl ColumnParallelLayer {
         Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
     }
+
+    pub fn new_merged(
+        in_dim: usize,
+        out_dim: usize,
+        chunks: usize,
+        config: &Option<QuantizedConfig>,
+        bias: bool,
+        comm: &Arc<crate::Comm>,
+        vb: ShardedVarBuilder,
+    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
+        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
+        for chunk_idx in 0..chunks {
+            let layer = ColumnParallelLayer::new_with_shard(
+                in_dim,
+                out_dim,
+                config,
+                bias,
+                comm,
+                shard(
+                    0,
+                    chunk_idx * comm.world_size() + comm.rank(),
+                    chunks * comm.world_size(),
+                ),
+                vb.clone(),
+            )?;
+            vec_layers.push(layer);
+        }
+        Ok(vec_layers)
+    }
 }
 
 impl QuantMethod for ColumnParallelLayer {