Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #161 from tehmatt/tehmatt-spinner
Browse files Browse the repository at this point in the history
Use a spinner for model loading information
  • Loading branch information
philpax authored May 4, 2023
2 parents cf203d2 + 064ba8c commit 5d61d81
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 119 deletions.
113 changes: 56 additions & 57 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion binaries/llm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ log = { workspace = true }
rand = { workspace = true }

bincode = "1.3.3"
bytesize = "1.1"
env_logger = "0.10.0"
num_cpus = "1.15.0"
rustyline = "11.0.0"
spinners = "4.1.0"
spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }

clap = { version = "4.1.8", features = ["derive"] }
color-eyre = { version = "0.6.2", default-features = false }
Expand Down
93 changes: 48 additions & 45 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,58 +311,61 @@ impl ModelLoad {
..Default::default()
};

let mut sp = Some(spinoff::Spinner::new(
spinoff::spinners::Dots2,
"Loading model...",
None,
));
let now = std::time::Instant::now();

let model = llm::load::<M>(
&self.model_path,
!self.no_mmap,
params,
load_progress_handler_log,
)
.wrap_err("Could not load model")?;

log::info!(
"Model fully loaded! Elapsed: {}ms",
now.elapsed().as_millis()
);
let model =
llm::load::<M>(
&self.model_path,
!self.no_mmap,
params,
move |progress| match progress {
LoadProgress::HyperparametersLoaded => {
if let Some(sp) = sp.as_mut() {
sp.update_text("Loaded hyperparameters")
};
}
LoadProgress::ContextSize { bytes } => log::debug!(
"ggml ctx size = {}",
bytesize::to_string(bytes as u64, false)
),
LoadProgress::TensorLoaded {
current_tensor,
tensor_count,
..
} => {
if let Some(sp) = sp.as_mut() {
sp.update_text(format!(
"Loaded tensor {}/{}",
current_tensor + 1,
tensor_count
));
};
}
LoadProgress::Loaded {
file_size,
tensor_count,
} => {
if let Some(sp) = sp.take() {
sp.success(&format!(
"Loaded {tensor_count} tensors ({}) after {}ms",
bytesize::to_string(file_size, false),
now.elapsed().as_millis()
));
};
}
},
)
.wrap_err("Could not load model")?;

Ok(Box::new(model))
}
}

/// Logs model-load progress events via the `log` crate.
///
/// Intended as a non-interactive progress handler for `llm::load` (used
/// where a spinner is not wanted): hyperparameter and context-size events
/// are logged as they arrive, per-tensor progress is throttled to every
/// 8th tensor to avoid flooding the output, and a final summary (model
/// size in MB plus tensor count) is emitted on completion.
pub(crate) fn load_progress_handler_log(progress: LoadProgress) {
    match progress {
        LoadProgress::HyperparametersLoaded => {
            log::debug!("Loaded hyperparameters")
        }
        LoadProgress::ContextSize { bytes } => log::info!(
            // Fix: the log backend appends its own newline, so a trailing
            // "\n" here printed a spurious blank line after this message.
            "ggml ctx size = {:.2} MB",
            bytes as f64 / (1024.0 * 1024.0)
        ),
        LoadProgress::TensorLoaded {
            current_tensor,
            tensor_count,
            ..
        } => {
            // Tensors are 0-indexed internally; report 1-indexed to the user.
            let current_tensor = current_tensor + 1;
            // Throttle: only log every 8th tensor.
            if current_tensor % 8 == 0 {
                log::info!("Loaded tensor {current_tensor}/{tensor_count}");
            }
        }
        LoadProgress::Loaded {
            byte_size,
            tensor_count,
        } => {
            log::info!("Loading of model complete");
            log::info!(
                "Model size = {:.2} MB / num tensors = {}",
                byte_size as f64 / 1024.0 / 1024.0,
                tensor_count
            );
        }
    }
}

#[derive(Parser, Debug)]
pub struct PromptFile {
/// A file to read the prompt from.
Expand Down
9 changes: 5 additions & 4 deletions binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {
let file = File::open(&args.model_path)?;
let mut reader = BufReader::new(&file);
let mut loader: llm::Loader<M::Hyperparameters, _> =
llm::Loader::new(cli_args::load_progress_handler_log);
let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(|_| {
// We purposely do not print progress here, as we are only interested in the metadata
});

llm::ggml_format::load(&mut reader, &mut loader)?;

Expand Down Expand Up @@ -192,7 +193,7 @@ fn interactive<M: llm::KnownModel + 'static>(
.map(|pf| process_prompt(pf, &line))
.unwrap_or(line);

let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string());
let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
if let Err(InferenceError::ContextFull) = session.feed_prompt::<Infallible>(
model.as_ref(),
&inference_params,
Expand All @@ -201,7 +202,7 @@ fn interactive<M: llm::KnownModel + 'static>(
) {
log::error!("Prompt exceeds context window length.")
};
sp.stop();
sp.clear();

let res = session.infer_with_params::<Infallible>(
model.as_ref(),
Expand Down
Loading

0 comments on commit 5d61d81

Please sign in to comment.