Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #135 from philpax/split-up-modules
Browse files Browse the repository at this point in the history
Split up modules
  • Loading branch information
philpax authored Apr 13, 2023
2 parents 5db8b4f + ec58e46 commit 4938dad
Show file tree
Hide file tree
Showing 9 changed files with 1,903 additions and 1,844 deletions.
93 changes: 46 additions & 47 deletions llama-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,60 +260,59 @@ pub struct ModelLoad {
pub num_ctx_tokens: usize,
}
impl ModelLoad {
pub fn load(&self) -> (llama_rs::Model, llama_rs::Vocabulary) {
let (model, vocabulary) =
llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| {
use llama_rs::LoadProgress;
match progress {
LoadProgress::HyperparametersLoaded(hparams) => {
log::debug!("Loaded hyperparameters {hparams:#?}")
}
LoadProgress::ContextSize { bytes } => log::info!(
"ggml ctx size = {:.2} MB\n",
bytes as f64 / (1024.0 * 1024.0)
),
LoadProgress::PartLoading {
file,
pub fn load(&self) -> llama_rs::Model {
let model = llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| {
use llama_rs::LoadProgress;
match progress {
LoadProgress::HyperparametersLoaded(hparams) => {
log::debug!("Loaded hyperparameters {hparams:#?}")
}
LoadProgress::ContextSize { bytes } => log::info!(
"ggml ctx size = {:.2} MB\n",
bytes as f64 / (1024.0 * 1024.0)
),
LoadProgress::PartLoading {
file,
current_part,
total_parts,
} => {
let current_part = current_part + 1;
log::info!(
"Loading model part {}/{} from '{}'\n",
current_part,
total_parts,
} => {
let current_part = current_part + 1;
log::info!(
"Loading model part {}/{} from '{}'\n",
current_part,
total_parts,
file.to_string_lossy(),
)
}
LoadProgress::PartTensorLoaded {
current_tensor,
tensor_count,
..
} => {
let current_tensor = current_tensor + 1;
if current_tensor % 8 == 0 {
log::info!("Loaded tensor {current_tensor}/{tensor_count}");
}
}
LoadProgress::PartLoaded {
file,
byte_size,
tensor_count,
} => {
log::info!("Loading of '{}' complete", file.to_string_lossy());
log::info!(
"Model size = {:.2} MB / num tensors = {}",
byte_size as f64 / 1024.0 / 1024.0,
tensor_count
);
file.to_string_lossy(),
)
}
LoadProgress::PartTensorLoaded {
current_tensor,
tensor_count,
..
} => {
let current_tensor = current_tensor + 1;
if current_tensor % 8 == 0 {
log::info!("Loaded tensor {current_tensor}/{tensor_count}");
}
}
})
.expect("Could not load model");
LoadProgress::PartLoaded {
file,
byte_size,
tensor_count,
} => {
log::info!("Loading of '{}' complete", file.to_string_lossy());
log::info!(
"Model size = {:.2} MB / num tensors = {}",
byte_size as f64 / 1024.0 / 1024.0,
tensor_count
);
}
}
})
.expect("Could not load model");

log::info!("Model fully loaded!");

(model, vocabulary)
model
}
}

Expand Down
11 changes: 4 additions & 7 deletions llama-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fn main() {
fn infer(args: &cli_args::Infer) {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_params = args.generate.inference_session_parameters();
let (model, vocabulary) = args.model_load.load();
let model = args.model_load.load();
let (mut session, session_loaded) = snapshot::read_or_create_session(
&model,
args.persist_session.as_deref(),
Expand All @@ -39,7 +39,6 @@ fn infer(args: &cli_args::Infer) {
let mut rng = args.generate.rng();
let res = session.inference_with_prompt::<Infallible>(
&model,
&vocabulary,
&inference_params,
&prompt,
args.generate.num_predict,
Expand Down Expand Up @@ -73,8 +72,8 @@ fn infer(args: &cli_args::Infer) {

fn dump_tokens(args: &cli_args::DumpTokens) {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let (_, vocabulary) = args.model_load.load();
let toks = match vocabulary.tokenize(&prompt, false) {
let model = args.model_load.load();
let toks = match model.vocabulary().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
log::error!("Could not tokenize prompt: {e}");
Expand Down Expand Up @@ -106,7 +105,7 @@ fn interactive(
) {
let prompt_file = args.prompt_file.contents();
let inference_session_params = args.generate.inference_session_parameters();
let (model, vocabulary) = args.model_load.load();
let model = args.model_load.load();
let (mut session, session_loaded) = snapshot::read_or_create_session(
&model,
None,
Expand Down Expand Up @@ -135,7 +134,6 @@ fn interactive(
let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string());
if let Err(InferenceError::ContextFull) = session.feed_prompt::<Infallible>(
&model,
&vocabulary,
&inference_params,
&prompt,
|_| Ok(()),
Expand All @@ -146,7 +144,6 @@ fn interactive(

let res = session.inference_with_prompt::<Infallible>(
&model,
&vocabulary,
&inference_params,
"",
args.generate.num_predict,
Expand Down
2 changes: 1 addition & 1 deletion llama-cli/src/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn read_or_create_session(
let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || {
format!("Could not deserialize inference session from {path:?}")
});
let session = unwrap_or_exit(model.session_from_snapshot(snapshot), || {
let session = unwrap_or_exit(InferenceSession::from_snapshot(snapshot, model), || {
format!("Could not convert snapshot from {path:?} to session")
});
log::info!("Loaded inference session from {path:?}");
Expand Down
Loading

0 comments on commit 4938dad

Please sign in to comment.