Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ categories = ["algorithms", "hardware-support", "science"]
license = "MIT"

[dependencies]
candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" }
candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "1e9d1a9" }
candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" }
candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" }
serde = { version = "1.0.190", features = ["serde_derive"] }
tokenizers = {version = "0.21.2", features = ["http"] }
hf-hub = "0.4.1"
Expand Down Expand Up @@ -44,7 +44,7 @@ ahash = "0.8.11"
reedline = "0.40.0"
pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true }
parking_lot = "0.12.4"
attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.1", rev = "af0b475" }
attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.1", rev = "29e4beb" }
once_cell = "1.21.3"
tqdm = "0.8.0"
futures = "0.3.31"
Expand All @@ -60,7 +60,7 @@ utoipa = { version = "4.2", features = ["axum_extras"] }
colored = { version = "3.0.0" }
tower-http = { version = "0.6.6", features = ["cors"] }
rustchatui = { git = "https://github.com/guoqingbao/rustchatui.git", rev = "68caad9" }
sysinfo = "0.37.2"
sysinfo = "0.38.3"
image = { version = "0.25.6", default-features = false, features = ['bmp', 'gif', 'jpeg', 'png', 'tiff', 'webp'] }
reqwest = { version = "0.12.24", features = ["blocking", "json", "rustls-tls"]}
bytemuck = "1.24.0"
Expand Down
77 changes: 41 additions & 36 deletions src/core/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,43 +175,48 @@ impl LLMEngine {
let reporter: Arc<RwLock<Box<dyn ProgressLike>>> =
Arc::new(RwLock::new(Box::new(ProgressReporter::new(0))));
let handle = progress_worker(1, config.num_hidden_layers, &reporter);
let vb = VarBuilderX::new(&model_pathes, is_gguf, dtype, &device)?;
let transfer = if let Some(p_cfg) = &econfig.pd_config {
Some(Arc::new(Transfer::new(
p_cfg.clone(),
0,
model_loaded.clone(),
stop_flag.clone(),
)?))
} else {
None
};

let mut model_runner = ModelRunner::new(
model_type.clone(),
&vb,
#[cfg(not(feature = "nccl"))]
Rc::new(Comm::default()),
#[cfg(feature = "nccl")]
Rc::new(
Comm::from_rank(
device.as_cuda_device().unwrap().cuda_device(),
let mut model_runner = {
let _guard = candle_core::InferenceMode::enter();
let vb = VarBuilderX::new(&model_pathes, is_gguf, dtype, &device)?;
let transfer = if let Some(p_cfg) = &econfig.pd_config {
Some(Arc::new(Transfer::new(
p_cfg.clone(),
0,
1,
Id::new().unwrap(),
)
.unwrap(),
),
&mut econfig,
&config,
dtype,
is_rope_i,
device.clone(),
reporter,
transfer,
toktrie.clone(),
None,
)?;
model_loaded.clone(),
stop_flag.clone(),
)?))
} else {
None
};

let runner = ModelRunner::new(
model_type.clone(),
&vb,
#[cfg(not(feature = "nccl"))]
Rc::new(Comm::default()),
#[cfg(feature = "nccl")]
Rc::new(
Comm::from_rank(
device.as_cuda_device().unwrap().cuda_device(),
0,
1,
Id::new().unwrap(),
)
.unwrap(),
),
&mut econfig,
&config,
dtype,
is_rope_i,
device.clone(),
reporter,
transfer,
toktrie.clone(),
None,
)?;
drop(vb);
runner
};

if !is_pd_server {
//No graph capture for PD server
Expand Down
19 changes: 13 additions & 6 deletions src/models/gemma3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,7 @@ impl Gemma3ForConditionalGeneration {
} else {
vb.pp("language_model.model.embed_tokens")
},
if is_qvar_builder || g_cfg.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;

let embed_scale = (config.text_config.hidden_size as f64).sqrt();
Expand Down Expand Up @@ -659,6 +655,17 @@ impl Gemma3ForConditionalGeneration {
})
}

/// Embeds `input_ids` with `self.embed_tokens` and multiplies the result by
/// `self.embed_scale`.
///
/// If `self.is_qvar_builder` is set or `self.g_cfg.quant` is present, the
/// embedding output is converted to `DType::F32` before scaling, unless it
/// already has that dtype.
///
/// # Errors
/// Propagates any error from the embedding lookup, the dtype conversion, or
/// the scaling multiplication.
fn embed_forward(&self, input_ids: &Tensor) -> Result<Tensor> {
    let embedded = self.embed_tokens.forward(input_ids)?;
    let wants_f32 = self.is_qvar_builder || self.g_cfg.quant.is_some();
    if wants_f32 && embedded.dtype() != DType::F32 {
        // Convert first so the scaling below runs on the F32 tensor.
        return embedded.to_dtype(DType::F32)? * self.embed_scale;
    }
    embedded * self.embed_scale
}

fn vision_tower(
&self,
image_features: &Tensor,
Expand Down Expand Up @@ -687,7 +694,7 @@ impl Gemma3ForConditionalGeneration {
) -> Result<Tensor> {
let text_cfg = &self.config.text_config;
// 1. Prepare Text Embeddings (Scaled)
let mut xs = (self.embed_tokens.forward(input_ids)? * self.embed_scale)?;
let mut xs = self.embed_forward(input_ids)?;

// vision projection and embedding
if let Some(images) = images {
Expand Down
15 changes: 8 additions & 7 deletions src/models/glm4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,7 @@ impl GLM4ForCausalLM {
} else {
vb.pp("model.embed_tokens")
},
if is_qvar_builder || config.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;
let rotary_emb = Arc::new(ScalingRotaryEmbedding::new(
if is_qvar_builder || config.quant.is_some() {
Expand Down Expand Up @@ -293,7 +289,12 @@ impl GLM4ForCausalLM {
}

pub fn embed_forward(&self, xs: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(xs)
let xs = self.embed_tokens.forward(xs)?;
if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 {
xs.to_dtype(DType::F32)
} else {
Ok(xs)
}
}

fn forward_inner(
Expand All @@ -319,7 +320,7 @@ impl GLM4ForCausalLM {
let mut xs = if embeded_inputs {
input_ids.to_owned()
} else {
self.embed_tokens.forward(input_ids)?
self.embed_forward(input_ids)?
};

if let Some(kv_caches) = kv_caches {
Expand Down
15 changes: 8 additions & 7 deletions src/models/glm4_moe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,7 @@ impl GLM4MoEForCausalLM {
} else {
vb.pp(&format!("{}embed_tokens", prefix))
},
if is_qvar_builder || config.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;
let rotary_emb = Arc::new(ScalingRotaryEmbedding::new(
if is_qvar_builder || config.quant.is_some() {
Expand Down Expand Up @@ -393,7 +389,12 @@ impl GLM4MoEForCausalLM {
}

pub fn embed_forward(&self, xs: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(xs)
let xs = self.embed_tokens.forward(xs)?;
if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 {
xs.to_dtype(DType::F32)
} else {
Ok(xs)
}
}

fn forward_inner(
Expand All @@ -420,7 +421,7 @@ impl GLM4MoEForCausalLM {
let mut xs = if embeded_inputs {
input_ids.to_owned()
} else {
self.embed_tokens.forward(input_ids)?
self.embed_forward(input_ids)?
};

if let Some(kv_caches) = kv_caches {
Expand Down
15 changes: 8 additions & 7 deletions src/models/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,7 @@ impl LLaMaForCausalLM {
} else {
vb.pp("model.embed_tokens").clone()
},
if is_qvar_builder || config.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;

let rotary_emb = Arc::new(ScalingRotaryEmbedding::new(
Expand Down Expand Up @@ -262,7 +258,12 @@ impl LLaMaForCausalLM {
}

pub fn embed_forward(&self, xs: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(xs)
let xs = self.embed_tokens.forward(xs)?;
if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 {
xs.to_dtype(DType::F32)
} else {
Ok(xs)
}
}

fn forward_inner(
Expand All @@ -287,7 +288,7 @@ impl LLaMaForCausalLM {
let mut xs = if embeded_inputs {
input_ids.to_owned()
} else {
self.embed_tokens.forward(input_ids)?
self.embed_forward(input_ids)?
};

if let Some(kv_caches) = kv_caches {
Expand Down
17 changes: 11 additions & 6 deletions src/models/phi4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,7 @@ impl Phi4ForCausalLM {
} else {
vb.pp("model.embed_tokens")
},
if is_qvar_builder || config.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;
let rotary_emb = Arc::new(Phi4RotaryEmbedding::new(
if is_qvar_builder || config.quant.is_some() {
Expand Down Expand Up @@ -595,6 +591,15 @@ impl Phi4ForCausalLM {
})
}

/// Embeds `xs` with `self.embed_tokens`.
///
/// If `self.is_qvar_builder` is set or `self.config.quant` is present, the
/// result is converted to `DType::F32` when it is not already F32; otherwise
/// the embedding output is returned unchanged.
///
/// # Errors
/// Propagates any error from the embedding lookup or the dtype conversion.
pub fn embed_forward(&self, xs: &Tensor) -> Result<Tensor> {
    let embedded = self.embed_tokens.forward(xs)?;
    let wants_f32 = self.is_qvar_builder || self.config.quant.is_some();
    // Nothing to convert when F32 is not required or the dtype already matches.
    if !wants_f32 || embedded.dtype() == DType::F32 {
        return Ok(embedded);
    }
    embedded.to_dtype(DType::F32)
}

fn forward_inner(
&self,
input_ids: &Tensor,
Expand All @@ -620,7 +625,7 @@ impl Phi4ForCausalLM {
let mut xs = if embeded_inputs {
input_ids.to_owned()
} else {
self.embed_tokens.forward(input_ids)?
self.embed_forward(input_ids)?
};

if let Some(kv_caches) = kv_caches {
Expand Down
15 changes: 8 additions & 7 deletions src/models/qwen3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,7 @@ impl Qwen3ForCausalLM {
} else {
vb.pp(&format!("{}embed_tokens", prefix))
},
if is_qvar_builder || config.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;
let rotary_emb = Arc::new(ScalingRotaryEmbedding::new(
if is_qvar_builder || config.quant.is_some() {
Expand Down Expand Up @@ -301,7 +297,12 @@ impl Qwen3ForCausalLM {
}

pub fn embed_forward(&self, xs: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(xs)
let xs = self.embed_tokens.forward(xs)?;
if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 {
xs.to_dtype(DType::F32)
} else {
Ok(xs)
}
}

fn forward_inner(
Expand All @@ -328,7 +329,7 @@ impl Qwen3ForCausalLM {
let mut xs = if embeded_inputs {
input_ids.to_owned()
} else {
self.embed_tokens.forward(input_ids)?
self.embed_forward(input_ids)?
};
if let Some(kv_caches) = kv_caches {
for ((k_cache, v_cache), (i, layer)) in
Expand Down
15 changes: 8 additions & 7 deletions src/models/qwen3_5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,7 @@ impl Qwen3_5ForCausalLM {
} else {
vb.pp(&format!("{}embed_tokens", prefix))
},
if is_qvar_builder || config.quant.is_some() {
DType::F32
} else {
dtype
},
dtype,
)?;

let rotary_emb = Arc::new(ScalingRotaryEmbedding::new(
Expand Down Expand Up @@ -475,7 +471,12 @@ impl Qwen3_5ForCausalLM {
}

pub fn embed_forward(&self, xs: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(xs)
let xs = self.embed_tokens.forward(xs)?;
if (self.is_qvar_builder || self.config.quant.is_some()) && xs.dtype() != DType::F32 {
xs.to_dtype(DType::F32)
} else {
Ok(xs)
}
}

fn forward_inner(
Expand Down Expand Up @@ -503,7 +504,7 @@ impl Qwen3_5ForCausalLM {
let mut xs = if embeded_inputs {
input_ids.to_owned()
} else {
self.embed_tokens.forward(input_ids)?
self.embed_forward(input_ids)?
};

let mut kv_cache_idx = 0usize;
Expand Down
Loading