
fix #48 - make InferenceSession clonable
philpax committed Mar 29, 2023
1 parent cdb630d commit 086e7db
Showing 1 changed file with 33 additions and 11 deletions: llama-rs/src/lib.rs
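
InferenceSession owns a ggml::Context, which wraps a raw C-side allocation and cannot be #[derive(Clone)]d. The commit therefore records the context's original size in a new memory_size field and implements Clone by hand: allocate a fresh context of that size, recreate the key/value memory tensors with the same type and length, and clone the remaining plain-data fields. In support of this, the mulf! size-computation macro and the context-size accumulators switch from u64 to usize, so the computed size can be stored and reused directly.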
@@ -75,6 +75,9 @@ pub struct InferenceSession {
     // Must be kept alive for the model
     _session_ctx: ggml::Context,
 
+    // Original size of the memory used to create this context.
+    memory_size: usize,
+
     // Parameters for the session.
     params: InferenceSessionParameters,
 
@@ -102,6 +105,31 @@ impl InferenceSession {
             .saturating_sub(self.params.repetition_penalty_last_n)..]
     }
 }
+impl Clone for InferenceSession {
+    fn clone(&self) -> Self {
+        let context = ggml::Context::init(self.memory_size);
+        let memory_k = context.new_tensor_1d(
+            self.memory_k.get_type(),
+            self.memory_k.get_ne()[0].try_into().unwrap(),
+        );
+        let memory_v = context.new_tensor_1d(
+            self.memory_v.get_type(),
+            self.memory_v.get_ne()[0].try_into().unwrap(),
+        );
+
+        Self {
+            _session_ctx: context,
+            memory_size: self.memory_size,
+            params: self.params.clone(),
+            memory_k,
+            memory_v,
+            n_past: self.n_past.clone(),
+            mem_per_token: self.mem_per_token.clone(),
+            tokens: self.tokens.clone(),
+            last_logits: self.last_logits.clone(),
+        }
+    }
+}
 
 #[derive(serde::Serialize, Clone, PartialEq)]
 /// A serializable snapshot of the inference process. Can be saved to disk.
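
Note that the new clone() allocates fresh memory_k/memory_v tensors of the same type and length, but no explicit copy of their contents appears in this hunk; only the session's bookkeeping (params, tokens, counters, logits) is duplicated. As a minimal, self-contained sketch of the pattern, with hypothetical stand-in types (Arena playing the role of ggml::Context, Session the role of InferenceSession):

// Hypothetical stand-in for a resource that cannot simply be cloned;
// ggml::Context similarly wraps a raw C-side allocation.
struct Arena {
    #[allow(dead_code)]
    buf: Vec<u8>,
}

impl Arena {
    fn with_capacity(size: usize) -> Self {
        Arena { buf: Vec::with_capacity(size) }
    }
}

struct Session {
    arena: Arena,      // must be re-created, not cloned
    arena_size: usize, // remembered so clone() can size the new arena
    tokens: Vec<u32>,  // plain data, cloned field by field
}

impl Clone for Session {
    fn clone(&self) -> Self {
        // Allocate a fresh arena of the original size instead of
        // trying to alias the existing allocation.
        Session {
            arena: Arena::with_capacity(self.arena_size),
            arena_size: self.arena_size,
            tokens: self.tokens.clone(),
        }
    }
}

fn main() {
    let a = Session {
        arena: Arena::with_capacity(1024),
        arena_size: 1024,
        tokens: vec![1, 2, 3],
    };
    let b = a.clone();
    assert_eq!(a.tokens, b.tokens);
}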
@@ -542,7 +570,7 @@ pub struct EvaluateOutputRequest {
 /// reported by ggml.
 macro_rules! mulf {
     ($term:expr, $($terms:expr),*) => {
-        (($term as f64) $(* ($terms as f64))*) as u64
+        usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap()
    };
 }
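
The revised macro still multiplies every term as f64 and truncates to u64, but now checked-converts the result to usize, so an overflow panics via unwrap() instead of wrapping silently. A worked example (the macro body is copied from the hunk above; the sizes below are illustrative, not taken from the model):

macro_rules! mulf {
    ($term:expr, $($terms:expr),*) => {
        usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap()
    };
}

fn main() {
    // Hypothetical sizes: a 4096-dim embedding table over a
    // 32000-token vocabulary, at 4 bytes (f32) per element.
    let (n_embd, n_vocab) = (4096u32, 32000u32);
    let bytes = mulf!(n_embd, n_vocab, 4);
    assert_eq!(bytes, 4096 * 32000 * 4);
    println!("tok_embeddings: {bytes} bytes");
}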

@@ -702,12 +730,7 @@ impl Model {
 
         let ctx_size = {
             // Use 64-bit math to prevent overflow.
-            let n_embd = n_embd as u64;
-            let n_layer = n_layer as u64;
-            let n_vocab = n_vocab as u64;
-            let n_ff = n_ff as u64;
-
-            let mut ctx_size: u64 = 0;
+            let mut ctx_size: usize = 0;
 
             ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings
 
@@ -730,9 +753,7 @@
 
             ctx_size += (5 + 10 * n_layer) * 256; // object overhead
 
-            load_progress_callback(LoadProgress::ContextSize {
-                bytes: ctx_size.try_into()?,
-            });
+            load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size });
 
             ctx_size
         };
@@ -1092,7 +1113,7 @@ impl Model {
                 n_embd,
                 ggml::type_sizef(params.memory_v_type.into())
             ); // memory_v
-            ctx_size += (5 + 10 * n_layer as u64) * 256; // object overhead
+            ctx_size += (5 + 10 * n_layer) * 256; // object overhead
 
            ctx_size
        };
@@ -1106,6 +1127,7 @@ impl Model {
 
         InferenceSession {
             _session_ctx: session_ctx,
+            memory_size: ctx_size,
             params,
             memory_k,
             memory_v,
