diff --git a/crates/llm-chain-llama-sys/llama.cpp b/crates/llm-chain-llama-sys/llama.cpp
index 173d0e64..5c64a095 160000
--- a/crates/llm-chain-llama-sys/llama.cpp
+++ b/crates/llm-chain-llama-sys/llama.cpp
@@ -1 +1 @@
-Subproject commit 173d0e6419e8f8f3c1f4f13201b777f4c60629f3
+Subproject commit 5c64a0952ee58b2d742ee84e8e3d43cce5d366db
diff --git a/crates/llm-chain-llama-sys/src/bindings.rs b/crates/llm-chain-llama-sys/src/bindings.rs
index 21d7e923..f94dd1ac 100644
--- a/crates/llm-chain-llama-sys/src/bindings.rs
+++ b/crates/llm-chain-llama-sys/src/bindings.rs
@@ -759,14 +759,21 @@ fn bindgen_test_layout_llama_token_data_array() {
         )
     );
 }
+
+const LLAMA_MAX_DEVICES: usize = 1;
 pub type llama_progress_callback =
     ::std::option::Option<unsafe extern "C" fn(progress: f32, ctx: *mut ::std::os::raw::c_void)>;
 #[repr(C)]
 #[derive(Debug, Copy, Clone)]
 pub struct llama_context_params {
     pub n_ctx: ::std::os::raw::c_int,
-    pub n_parts: ::std::os::raw::c_int,
+    pub n_batch: ::std::os::raw::c_int,
+    pub n_gpu_layers: ::std::os::raw::c_int,
+    pub main_gpu: ::std::os::raw::c_int,
+    pub tensor_split: [::std::os::raw::c_float; LLAMA_MAX_DEVICES],
+
     pub seed: ::std::os::raw::c_int,
+
     pub f16_kv: bool,
     pub logits_all: bool,
     pub vocab_only: bool,
@@ -800,16 +807,7 @@ fn bindgen_test_layout_llama_context_params() {
             stringify!(n_ctx)
         )
     );
-    assert_eq!(
-        unsafe { ::std::ptr::addr_of!((*ptr).n_parts) as usize - ptr as usize },
-        4usize,
-        concat!(
-            "Offset of field: ",
-            stringify!(llama_context_params),
-            "::",
-            stringify!(n_parts)
-        )
-    );
+
     assert_eq!(
         unsafe { ::std::ptr::addr_of!((*ptr).seed) as usize - ptr as usize },
         8usize,
diff --git a/crates/llm-chain-llama/examples/simple.rs b/crates/llm-chain-llama/examples/simple.rs
index 8ff95148..e5c088f3 100644
--- a/crates/llm-chain-llama/examples/simple.rs
+++ b/crates/llm-chain-llama/examples/simple.rs
@@ -1,5 +1,7 @@
 use llm_chain::{executor, parameters, prompt};
-
+use llm_chain::options;
+use llm_chain::options::{ModelRef, Options};
+use std::{env::args, error::Error};
 /// This example demonstrates how to use the llm-chain-llama crate to generate text using a
 /// LLaMA model.
 ///
@@ -9,11 +11,42 @@ use llm_chain::{executor, parameters, prompt};
 /// cargo run --example simple /models/llama
 #[tokio::main(flavor = "current_thread")]
 async fn main() -> Result<(), Box<dyn Error>> {
-    let exec = executor!(llama)?;
+    let raw_args: Vec<String> = args().collect();
+    let args = match &raw_args.len() {
+        2 => (raw_args[1].as_str(), "Rust is a cool programming language because"),
+        3 => (raw_args[1].as_str(), raw_args[2].as_str()),
+        _ => panic!("Usage: cargo run --release --example simple <path to model> <optional prompt>")
+    };
+
+    let model_path = args.0;
+    let prompt = args.1;
+    let opts = options!(
+        Model: ModelRef::from_path(model_path),
+        ModelType: "llama",
+        MaxContextSize: 512_usize,
+        NThreads: 4_usize,
+        MaxTokens: 0_usize,
+        TopK: 40_i32,
+        TopP: 0.95,
+        TfsZ: 1.0,
+        TypicalP: 1.0,
+        Temperature: 0.8,
+        RepeatPenalty: 1.1,
+        RepeatPenaltyLastN: 64_usize,
+        FrequencyPenalty: 0.0,
+        PresencePenalty: 0.0,
+        Mirostat: 0_i32,
+        MirostatTau: 5.0,
+        MirostatEta: 0.1,
+        PenalizeNl: true,
+        StopSequence: vec!["\n".to_string()]
+    );
+    let exec = executor!(llama, opts.clone())?;
 
-    let res = prompt!("The Colors of the Rainbow are (in order): ")
+    let res = prompt!(prompt)
         .run(&parameters!(), &exec)
         .await?;
+    println!("{}", res.to_immediate().await?);
 
     Ok(())
 }
diff --git a/crates/llm-chain-llama/src/context.rs b/crates/llm-chain-llama/src/context.rs
index 0ffe8a67..4d5a801e 100644
--- a/crates/llm-chain-llama/src/context.rs
+++ b/crates/llm-chain-llama/src/context.rs
@@ -20,11 +20,15 @@ use serde::{Deserialize, Serialize};
 #[error("LLAMA.cpp returned error-code {0}")]
 pub struct LLAMACPPErrorCode(i32);
 
+const LLAMA_MAX_DEVICES: usize = 1; // corresponding to constant in llama.h
 // Represents the configuration parameters for a LLamaContext.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ContextParams {
     pub n_ctx: i32,
-    pub n_parts: i32,
+    pub n_batch: i32,
+    pub n_gpu_layers: i32,
+    pub main_gpu: i32,
+    pub tensor_split: [f32; LLAMA_MAX_DEVICES],
     pub seed: i32,
     pub f16_kv: bool,
     pub vocab_only: bool,
@@ -56,7 +60,10 @@ impl From<ContextParams> for llama_context_params {
     fn from(params: ContextParams) -> Self {
         llama_context_params {
             n_ctx: params.n_ctx,
-            n_parts: params.n_parts,
+            n_batch: params.n_batch,
+            n_gpu_layers: params.n_gpu_layers,
+            main_gpu: params.main_gpu,
+            tensor_split: params.tensor_split,
             seed: params.seed,
             f16_kv: params.f16_kv,
             logits_all: false,
@@ -74,7 +81,10 @@ impl From<llama_context_params> for ContextParams {
     fn from(params: llama_context_params) -> Self {
         ContextParams {
             n_ctx: params.n_ctx,
-            n_parts: params.n_parts,
+            n_batch: params.n_batch,
+            n_gpu_layers: params.n_gpu_layers,
+            main_gpu: params.main_gpu,
+            tensor_split: params.tensor_split,
             seed: params.seed,
             f16_kv: params.f16_kv,
             vocab_only: params.vocab_only,
diff --git a/crates/llm-chain-llama/src/executor.rs b/crates/llm-chain-llama/src/executor.rs
index 77bab488..549a0f42 100644
--- a/crates/llm-chain-llama/src/executor.rs
+++ b/crates/llm-chain-llama/src/executor.rs
@@ -51,7 +51,7 @@ impl Executor {
 
     // Run the LLAMA model with the provided input and generate output.
     // Executes the model with the provided input and context parameters.
-    fn run_model(&self, input: LlamaInvocation) -> Output {
+    async fn run_model(&self, input: LlamaInvocation) -> Output {
         let (sender, output) = Output::new_stream();
         // Tokenize the stop sequence and input prompt.
         let context = self.context.clone();
@@ -62,7 +62,6 @@ impl Executor {
             async move {
                 let context_size = context_size;
                 let context = context.lock().await;
-
                 let tokenized_stop_prompt = tokenize(
                     &context,
                     input
@@ -87,7 +86,7 @@ impl Executor {
 
                 // Embd contains the prompt and the completion. The longer the prompt, the shorter the completion.
                 let mut embd = tokenized_input.clone();
-
+
                 // Evaluate the prompt in full.
                 bail!(
                     context
@@ -180,7 +179,7 @@ impl Executor {
                     }
                 }
             }
-        });
+        }).await.unwrap().await;
         output
     }
 
@@ -210,7 +209,7 @@ impl ExecutorTrait for Executor {
     async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
         let invocation = LlamaInvocation::new(self.get_cascade(options), prompt)
             .ok_or(ExecutorError::InvalidOptions)?;
-        Ok(self.run_model(invocation))
+        Ok(self.run_model(invocation).await)
     }
 
     fn tokens_used(
diff --git a/crates/llm-chain-llama/src/options.rs b/crates/llm-chain-llama/src/options.rs
index baa12e80..bff8236f 100644
--- a/crates/llm-chain-llama/src/options.rs
+++ b/crates/llm-chain-llama/src/options.rs
@@ -43,7 +43,8 @@ impl LlamaInvocation {
     pub(crate) fn new(opt: OptionsCascade, prompt: &Prompt) -> Option<LlamaInvocation> {
         opt_extract!(opt, n_threads, NThreads);
         opt_extract!(opt, n_tok_predict, MaxTokens);
-        opt_extract!(opt, token_bias, TokenBias);
+        // Skip TokenBias for now
+        //opt_extract!(opt, token_bias, TokenBias);
         opt_extract!(opt, top_k, TopK);
         opt_extract!(opt, top_p, TopP);
         opt_extract!(opt, tfs_z, TfsZ);
@@ -59,8 +60,8 @@ impl LlamaInvocation {
         opt_extract!(opt, penalize_nl, PenalizeNl);
         opt_extract!(opt, stop_sequence, StopSequence);
 
-        let logit_bias = token_bias.as_i32_f32_hashmap()?;
-
+        let logit_bias = HashMap::<i32, f32>::new(); // token_bias.as_i32_f32_hashmap()?;
+
         Some(LlamaInvocation {
             n_threads: *n_threads as i32,
             n_tok_predict: *n_tok_predict,
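
For reference, the sketch below shows how a caller might populate the new GPU-related ContextParams fields introduced by this patch (n_batch, n_gpu_layers, main_gpu, tensor_split). It is a minimal, illustrative sketch only: the import path, the availability of a Default implementation, the helper name gpu_context_params, and the concrete values (512 and 32) are assumptions for illustration and are not established by the diff itself.

    // Illustrative sketch; the `llm_chain_llama::context::ContextParams` path,
    // the `Default` impl, and all concrete values are assumptions, not part of
    // this patch.
    use llm_chain_llama::context::ContextParams;

    fn gpu_context_params() -> ContextParams {
        let mut params = ContextParams::default();
        params.n_batch = 512; // prompt-processing batch size handed to llama.cpp
        params.n_gpu_layers = 32; // number of layers to offload to the GPU (example value)
        params.main_gpu = 0; // device used for scratch buffers and small tensors
        params.tensor_split = [1.0]; // single entry, since LLAMA_MAX_DEVICES is 1 in this build
        params
    }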