diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index 8bc98240..de7573d6 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -75,6 +75,9 @@ pub struct InferenceSession {
     // Must be kept alive for the model
     _session_ctx: ggml::Context,
 
+    // Original size of the memory used to create this context.
+    memory_size: usize,
+
     // Parameters for the session.
     params: InferenceSessionParameters,
 
@@ -102,6 +105,31 @@ impl InferenceSession {
             .saturating_sub(self.params.repetition_penalty_last_n)..]
     }
 }
+impl Clone for InferenceSession {
+    fn clone(&self) -> Self {
+        let context = ggml::Context::init(self.memory_size);
+        let memory_k = context.new_tensor_1d(
+            self.memory_k.get_type(),
+            self.memory_k.get_ne()[0].try_into().unwrap(),
+        );
+        let memory_v = context.new_tensor_1d(
+            self.memory_v.get_type(),
+            self.memory_v.get_ne()[0].try_into().unwrap(),
+        );
+
+        Self {
+            _session_ctx: context,
+            memory_size: self.memory_size,
+            params: self.params.clone(),
+            memory_k,
+            memory_v,
+            n_past: self.n_past.clone(),
+            mem_per_token: self.mem_per_token.clone(),
+            tokens: self.tokens.clone(),
+            last_logits: self.last_logits.clone(),
+        }
+    }
+}
 
 #[derive(serde::Serialize, Clone, PartialEq)]
 /// A serializable snapshot of the inference process. Can be saved to disk.
@@ -542,7 +570,7 @@ pub struct EvaluateOutputRequest {
 /// reported by ggml.
 macro_rules! mulf {
     ($term:expr, $($terms:expr),*) => {
-        (($term as f64) $(* ($terms as f64))*) as u64
+        usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap()
     };
 }
 
@@ -702,12 +730,7 @@ impl Model {
 
         let ctx_size = {
             // Use 64-bit math to prevent overflow.
-            let n_embd = n_embd as u64;
-            let n_layer = n_layer as u64;
-            let n_vocab = n_vocab as u64;
-            let n_ff = n_ff as u64;
-
-            let mut ctx_size: u64 = 0;
+            let mut ctx_size: usize = 0;
 
             ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings
 
@@ -730,9 +753,7 @@ impl Model {
 
             ctx_size += (5 + 10 * n_layer) * 256; // object overhead
 
-            load_progress_callback(LoadProgress::ContextSize {
-                bytes: ctx_size.try_into()?,
-            });
+            load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size });
 
             ctx_size
         };
@@ -1092,7 +1113,7 @@ impl Model {
                 n_embd,
                 ggml::type_sizef(params.memory_v_type.into())
             ); // memory_v
-            ctx_size += (5 + 10 * n_layer as u64) * 256; // object overhead
+            ctx_size += (5 + 10 * n_layer) * 256; // object overhead
 
             ctx_size
         };
@@ -1106,6 +1127,7 @@ impl Model {
 
         InferenceSession {
             _session_ctx: session_ctx,
+            memory_size: ctx_size,
             params,
             memory_k,
             memory_v,
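
A note on the `Clone` implementation above: as written, it allocates fresh `memory_k`/`memory_v` tensors in the new context but never copies their contents, while `n_past` and `tokens` are carried over, so the cloned session's KV cache starts out empty even though its bookkeeping says otherwise. Below is a minimal sketch of the missing copy step, assuming the ggml wrapper exposes the raw-pointer `data()` and `nbytes()` accessors that the snapshot code uses elsewhere; `copy_tensor_data` is a hypothetical helper, not part of this patch.

```rust
use std::ptr::copy_nonoverlapping;

/// Hypothetical helper: byte-copy one ggml tensor's contents into another of
/// the same type and shape. Assumes `ggml::Tensor::data()` returns a raw
/// pointer to the tensor's buffer and `nbytes()` returns its size in bytes.
unsafe fn copy_tensor_data(src: &ggml::Tensor, dst: &ggml::Tensor) {
    // Both tensors were created with the same type and element count, so
    // their byte sizes must match before we copy.
    assert_eq!(src.nbytes(), dst.nbytes());
    copy_nonoverlapping(src.data() as *const u8, dst.data() as *mut u8, src.nbytes());
}

// Inside `clone`, after the two `new_tensor_1d` calls, something like:
//
// unsafe {
//     copy_tensor_data(&self.memory_k, &memory_k);
//     copy_tensor_data(&self.memory_v, &memory_v);
// }
```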
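
For reference, the reworked `mulf!` now yields a `usize` directly: the product is still computed in `f64` (so the fractional per-element sizes that `ggml::type_sizef` reports for quantized types multiply out correctly), truncated to `u64`, then converted with `usize::try_from(..).unwrap()`, which only panics if the result overflows the platform's pointer width. A standalone sketch of the behaviour, using made-up sizes:

```rust
// Same macro body as in the patch; the values below are hypothetical.
macro_rules! mulf {
    ($term:expr, $($terms:expr),*) => {
        usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap()
    };
}

fn main() {
    let n_embd = 4096;
    let n_vocab = 32000;
    // Hypothetical fractional per-element size, as ggml's quantized types have.
    let bytes_per_element = 0.5625_f64;
    // 4096 * 32000 * 0.5625 = 73_728_000 bytes, computed in f64 so the
    // intermediate product cannot overflow an integer type.
    let ctx_size: usize = mulf!(n_embd, n_vocab, bytes_per_element);
    println!("{ctx_size}");
}
```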