diff --git a/Cargo.toml b/Cargo.toml index adfda6b9..ab7071fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,8 @@ categories = ["algorithms", "hardware-support", "science"] license = "MIT" [dependencies] -candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" } -candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" } +candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } +candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" } serde = { version = "1.0.190", features = ["serde_derive"] } tokenizers = {version = "0.21.2", features = ["http"] } hf-hub = "0.4.1" @@ -44,7 +44,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.1", rev = "af0b475" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.1", rev = "29e4beb" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" @@ -60,7 +60,7 @@ utoipa = { version = "4.2", features = ["axum_extras"] } colored = { version = "3.0.0" } tower-http = { version = "0.6.6", features = ["cors"] } rustchatui = { git = "https://github.com/guoqingbao/rustchatui.git", rev = "68caad9" } -sysinfo = "0.37.2" +sysinfo = "0.38.3" image = { version = "0.25.6", default-features = false, features = ['bmp', 'gif', 'jpeg', 'png', 'tiff', 'webp'] } reqwest = { version = "0.12.24", features = ["blocking", "json", "rustls-tls"]} bytemuck = "1.24.0" diff --git a/src/core/engine.rs b/src/core/engine.rs index 0812b5a3..98da3720 100644 --- a/src/core/engine.rs +++ b/src/core/engine.rs @@ -175,43 +175,48 @@ impl LLMEngine { let reporter: Arc>> = Arc::new(RwLock::new(Box::new(ProgressReporter::new(0)))); let handle = progress_worker(1, config.num_hidden_layers, &reporter); - let vb = VarBuilderX::new(&model_pathes, is_gguf, dtype, &device)?; - let transfer = if let Some(p_cfg) = &econfig.pd_config { - Some(Arc::new(Transfer::new( - p_cfg.clone(), - 0, - model_loaded.clone(), - stop_flag.clone(), - )?)) - } else { - None - }; - - let mut model_runner = ModelRunner::new( - model_type.clone(), - &vb, - #[cfg(not(feature = "nccl"))] - Rc::new(Comm::default()), - #[cfg(feature = "nccl")] - Rc::new( - Comm::from_rank( - device.as_cuda_device().unwrap().cuda_device(), + let mut model_runner = { + let _guard = candle_core::InferenceMode::enter(); + let vb = VarBuilderX::new(&model_pathes, is_gguf, dtype, &device)?; + let transfer = if let Some(p_cfg) = &econfig.pd_config { + Some(Arc::new(Transfer::new( + p_cfg.clone(), 0, - 1, - Id::new().unwrap(), - ) - .unwrap(), - ), - &mut econfig, - &config, - dtype, - is_rope_i, - device.clone(), - reporter, - transfer, - toktrie.clone(), - None, - )?; + model_loaded.clone(), + stop_flag.clone(), + )?)) + } else { + None + }; + + let runner = ModelRunner::new( + model_type.clone(), + &vb, + #[cfg(not(feature = "nccl"))] + Rc::new(Comm::default()), + #[cfg(feature = "nccl")] + Rc::new( + Comm::from_rank( + device.as_cuda_device().unwrap().cuda_device(), + 0, + 1, + Id::new().unwrap(), + ) + .unwrap(), + ), + &mut econfig, + &config, + dtype, + is_rope_i, + device.clone(), + reporter, + transfer, + toktrie.clone(), + None, + )?; + drop(vb); + runner + }; if !is_pd_server { //No graph capture for PD server diff --git a/src/models/gemma3/mod.rs b/src/models/gemma3/mod.rs index 68e435c2..600aafd6 100644 --- a/src/models/gemma3/mod.rs +++ b/src/models/gemma3/mod.rs @@ -538,11 +538,7 @@ impl Gemma3ForConditionalGeneration { } else { vb.pp("language_model.model.embed_tokens") }, - if is_qvar_builder || g_cfg.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let embed_scale = (config.text_config.hidden_size as f64).sqrt(); @@ -659,6 +655,17 @@ impl Gemma3ForConditionalGeneration { }) } + fn embed_forward(&self, input_ids: &Tensor) -> Result { + let xs = self.embed_tokens.forward(input_ids)?; + let xs = if (self.is_qvar_builder || self.g_cfg.quant.is_some()) && xs.dtype() != DType::F32 + { + xs.to_dtype(DType::F32)? + } else { + xs + }; + xs * self.embed_scale + } + fn vision_tower( &self, image_features: &Tensor, @@ -687,7 +694,7 @@ impl Gemma3ForConditionalGeneration { ) -> Result { let text_cfg = &self.config.text_config; // 1. Prepare Text Embeddings (Scaled) - let mut xs = (self.embed_tokens.forward(input_ids)? * self.embed_scale)?; + let mut xs = self.embed_forward(input_ids)?; // vision projection and embedding if let Some(images) = images { diff --git a/src/models/glm4.rs b/src/models/glm4.rs index bdce6101..33ea6608 100644 --- a/src/models/glm4.rs +++ b/src/models/glm4.rs @@ -206,11 +206,7 @@ impl GLM4ForCausalLM { } else { vb.pp("model.embed_tokens") }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -293,7 +289,12 @@ impl GLM4ForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -319,7 +320,7 @@ impl GLM4ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/glm4_moe.rs b/src/models/glm4_moe.rs index 18c7c75b..b2d1855c 100644 --- a/src/models/glm4_moe.rs +++ b/src/models/glm4_moe.rs @@ -305,11 +305,7 @@ impl GLM4MoEForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -393,7 +389,12 @@ impl GLM4MoEForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -420,7 +421,7 @@ impl GLM4MoEForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/llama.rs b/src/models/llama.rs index 2bf3e408..450512d4 100644 --- a/src/models/llama.rs +++ b/src/models/llama.rs @@ -173,11 +173,7 @@ impl LLaMaForCausalLM { } else { vb.pp("model.embed_tokens").clone() }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( @@ -262,7 +258,12 @@ impl LLaMaForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -287,7 +288,7 @@ impl LLaMaForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/phi4.rs b/src/models/phi4.rs index 05fe2759..7444eb08 100644 --- a/src/models/phi4.rs +++ b/src/models/phi4.rs @@ -513,11 +513,7 @@ impl Phi4ForCausalLM { } else { vb.pp("model.embed_tokens") }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(Phi4RotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -595,6 +591,15 @@ impl Phi4ForCausalLM { }) } + pub fn embed_forward(&self, xs: &Tensor) -> Result { + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } + } + fn forward_inner( &self, input_ids: &Tensor, @@ -620,7 +625,7 @@ impl Phi4ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/models/qwen3.rs b/src/models/qwen3.rs index 8d7f1bb0..6d653a42 100644 --- a/src/models/qwen3.rs +++ b/src/models/qwen3.rs @@ -214,11 +214,7 @@ impl Qwen3ForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -301,7 +297,12 @@ impl Qwen3ForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -328,7 +329,7 @@ impl Qwen3ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { for ((k_cache, v_cache), (i, layer)) in diff --git a/src/models/qwen3_5.rs b/src/models/qwen3_5.rs index d7a339d5..b22dffd1 100644 --- a/src/models/qwen3_5.rs +++ b/src/models/qwen3_5.rs @@ -328,11 +328,7 @@ impl Qwen3_5ForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( @@ -475,7 +471,12 @@ impl Qwen3_5ForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -503,7 +504,7 @@ impl Qwen3_5ForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; let mut kv_cache_idx = 0usize; diff --git a/src/models/qwen3_5_moe.rs b/src/models/qwen3_5_moe.rs index a014fb65..529a9caf 100644 --- a/src/models/qwen3_5_moe.rs +++ b/src/models/qwen3_5_moe.rs @@ -441,11 +441,7 @@ impl Qwen3_5MoEForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( @@ -586,7 +582,12 @@ impl Qwen3_5MoEForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -613,7 +614,7 @@ impl Qwen3_5MoEForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; let mut kv_cache_idx = 0usize; diff --git a/src/models/qwen3_moe.rs b/src/models/qwen3_moe.rs index 6ad153bc..52e30187 100644 --- a/src/models/qwen3_moe.rs +++ b/src/models/qwen3_moe.rs @@ -348,11 +348,7 @@ impl Qwen3MoEForCausalLM { } else { vb.pp(&format!("{}embed_tokens", prefix)) }, - if is_qvar_builder || config.quant.is_some() { - DType::F32 - } else { - dtype - }, + dtype, )?; let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( if is_qvar_builder || config.quant.is_some() { @@ -436,7 +432,12 @@ impl Qwen3MoEForCausalLM { } pub fn embed_forward(&self, xs: &Tensor) -> Result { - self.embed_tokens.forward(xs) + let xs = self.embed_tokens.forward(xs)?; + if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 { + xs.to_dtype(DType::F32) + } else { + Ok(xs) + } } fn forward_inner( @@ -463,7 +464,7 @@ impl Qwen3MoEForCausalLM { let mut xs = if embeded_inputs { input_ids.to_owned() } else { - self.embed_tokens.forward(input_ids)? + self.embed_forward(input_ids)? }; if let Some(kv_caches) = kv_caches { diff --git a/src/runner/runner.rs b/src/runner/runner.rs index ee3d8bed..2ba0cf13 100644 --- a/src/runner/runner.rs +++ b/src/runner/runner.rs @@ -126,31 +126,36 @@ fn main() -> anyhow::Result<()> { (None, false) }; - let vb = VarBuilderX::new( - &init_req.model_pathes, - init_req.is_gguf, - init_req.dtype.into(), - &device, - )?; let stream_kv = Some(stream.try_clone()?); let mut econfig = init_req.econfig.clone(); let toktrie = load_toktrie_from_path(&init_req.model_pathes.get_tokenizer_filename()) .map(Arc::new); #[allow(unused_mut)] - let mut runner = ModelRunner::new( - init_req.model_type, - &vb, - comm, - &mut econfig, - &init_req.config, - init_req.dtype.into(), - init_req.is_rope_i, - device, - progress_reporter, - transfer, - toktrie, - stream_kv, - )?; + let mut runner = { + let _guard = candle_core::InferenceMode::enter(); + let vb = VarBuilderX::new( + &init_req.model_pathes, + init_req.is_gguf, + init_req.dtype.into(), + &device, + )?; + let runner = ModelRunner::new( + init_req.model_type, + &vb, + comm, + &mut econfig, + &init_req.config, + init_req.dtype.into(), + init_req.is_rope_i, + device, + progress_reporter, + transfer, + toktrie, + stream_kv, + )?; + drop(vb); + runner + }; vllm_rs::log_info!( "Runner at rank {} created (PD config: {:?})!",