Merged
2 changes: 1 addition & 1 deletion crates/llm-chain-llama-sys/llama.cpp
20 changes: 9 additions & 11 deletions crates/llm-chain-llama-sys/src/bindings.rs
@@ -759,14 +759,21 @@ fn bindgen_test_layout_llama_token_data_array() {
)
);
}

const LLAMA_MAX_DEVICES: usize = 1;
pub type llama_progress_callback =
::std::option::Option<unsafe extern "C" fn(progress: f32, ctx: *mut ::std::os::raw::c_void)>;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct llama_context_params {
pub n_ctx: ::std::os::raw::c_int,
pub n_parts: ::std::os::raw::c_int,
pub n_batch: ::std::os::raw::c_int,
pub n_gpu_layers: ::std::os::raw::c_int,
pub main_gpu: ::std::os::raw::c_int,
pub tensor_split: [f32; LLAMA_MAX_DEVICES],

pub seed: ::std::os::raw::c_int,

pub f16_kv: bool,
pub logits_all: bool,
pub vocab_only: bool,
@@ -800,16 +807,7 @@ fn bindgen_test_layout_llama_context_params() {
stringify!(n_ctx)
)
);
assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).n_parts) as usize - ptr as usize },
4usize,
concat!(
"Offset of field: ",
stringify!(llama_context_params),
"::",
stringify!(n_parts)
)
);

assert_eq!(
unsafe { ::std::ptr::addr_of!((*ptr).seed) as usize - ptr as usize },
8usize,
39 changes: 36 additions & 3 deletions crates/llm-chain-llama/examples/simple.rs
@@ -1,5 +1,7 @@
use llm_chain::{executor, parameters, prompt};

use llm_chain::options;
use llm_chain::options::{ModelRef, Options};
use std::{env::args, error::Error};
/// This example demonstrates how to use the llm-chain-llama crate to generate text using a
/// LLaMA model.
///
@@ -9,11 +11,42 @@ use llm_chain::{executor, parameters, prompt};
/// cargo run --example simple /models/llama
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let exec = executor!(llama)?;
let raw_args: Vec<String> = args().collect();
let args = match &raw_args.len() {
2 => (raw_args[1].as_str(), "Rust is a cool programming language because"),
3 => (raw_args[1].as_str(), raw_args[2].as_str()),
_ => panic!("Usage: cargo run --release --example simple <path to model> <optional prompt>")
};

let model_path = args.0;
let prompt = args.1;
let opts = options!(
Model: ModelRef::from_path(model_path),
ModelType: "llama",
MaxContextSize: 512_usize,
NThreads: 4_usize,
MaxTokens: 0_usize,
TopK: 40_i32,
TopP: 0.95,
TfsZ: 1.0,
TypicalP: 1.0,
Temperature: 0.8,
RepeatPenalty: 1.1,
RepeatPenaltyLastN: 64_usize,
FrequencyPenalty: 0.0,
PresencePenalty: 0.0,
Mirostat: 0_i32,
MirostatTau: 5.0,
MirostatEta: 0.1,
PenalizeNl: true,
StopSequence: vec!["\n".to_string()]
);
let exec = executor!(llama, opts.clone())?;

let res = prompt!("The Colors of the Rainbow are (in order): ")
let res = prompt!(prompt)
.run(&parameters!(), &exec)
.await?;

println!("{}", res.to_immediate().await?);
Ok(())
}
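
With this change the example takes the model path (and, optionally, a prompt) from the command line instead of hard-coding them; with no prompt given it falls back to "Rust is a cool programming language because". A typical invocation (the model path is a placeholder) looks like:

cargo run --release --example simple /models/llama "The Colors of the Rainbow are (in order): "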
16 changes: 13 additions & 3 deletions crates/llm-chain-llama/src/context.rs
@@ -20,11 +20,15 @@ use serde::{Deserialize, Serialize};
#[error("LLAMA.cpp returned error-code {0}")]
pub struct LLAMACPPErrorCode(i32);

const LLAMA_MAX_DEVICES: usize = 1; // corresponds to the constant of the same name in llama.h
// Represents the configuration parameters for a LLamaContext.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextParams {
pub n_ctx: i32,
pub n_parts: i32,
pub n_batch: i32,
pub n_gpu_layers: i32,
pub main_gpu: i32,
pub tensor_split: [f32; LLAMA_MAX_DEVICES],
pub seed: i32,
pub f16_kv: bool,
pub vocab_only: bool,
@@ -56,7 +60,10 @@ impl From<ContextParams> for llama_context_params {
fn from(params: ContextParams) -> Self {
llama_context_params {
n_ctx: params.n_ctx,
n_parts: params.n_parts,
n_batch: params.n_batch,
n_gpu_layers: params.n_gpu_layers,
main_gpu: params.main_gpu,
tensor_split: params.tensor_split,
seed: params.seed,
f16_kv: params.f16_kv,
logits_all: false,
@@ -74,7 +81,10 @@ impl From<llama_context_params> for ContextParams {
fn from(params: llama_context_params) -> Self {
ContextParams {
n_ctx: params.n_ctx,
n_parts: params.n_parts,
n_batch: params.n_batch,
n_gpu_layers: params.n_gpu_layers,
main_gpu: params.main_gpu,
tensor_split: params.tensor_split,
seed: params.seed,
f16_kv: params.f16_kv,
vocab_only: params.vocab_only,
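
The added fields mirror llama.cpp's batching and GPU-offload settings. As a rough illustration of what they control (a sketch only; it assumes a Default constructor for ContextParams, which this diff does not show):

let mut params = ContextParams::default(); // assumed constructor, not shown in this diff
params.n_batch = 512;         // how many tokens are fed to evaluation per batch
params.n_gpu_layers = 32;     // how many transformer layers to offload to the GPU
params.main_gpu = 0;          // which device handles scratch buffers and small tensors
params.tensor_split = [1.0];  // per-device split; LLAMA_MAX_DEVICES is 1 here, so a single entry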
9 changes: 4 additions & 5 deletions crates/llm-chain-llama/src/executor.rs
@@ -51,7 +51,7 @@ impl Executor {

// Run the LLAMA model with the provided input and generate output.
// Executes the model with the provided input and context parameters.
fn run_model(&self, input: LlamaInvocation) -> Output {
async fn run_model(&self, input: LlamaInvocation) -> Output {
let (sender, output) = Output::new_stream();
// Tokenize the stop sequence and input prompt.
let context = self.context.clone();
@@ -62,7 +62,6 @@ impl Executor {
async move {
let context_size = context_size;
let context = context.lock().await;

let tokenized_stop_prompt = tokenize(
&context,
input
@@ -87,7 +86,7 @@ impl Executor {

// Embd contains the prompt and the completion. The longer the prompt, the shorter the completion.
let mut embd = tokenized_input.clone();

// Evaluate the prompt in full.
bail!(
context
@@ -180,7 +179,7 @@ impl Executor {
}
}
}
});
}).await.unwrap().await;

output
}
@@ -210,7 +209,7 @@ impl ExecutorTrait for Executor {
async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
let invocation = LlamaInvocation::new(self.get_cascade(options), prompt)
.ok_or(ExecutorError::InvalidOptions)?;
Ok(self.run_model(invocation))
Ok(self.run_model(invocation).await)
}

fn tokens_used(
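
run_model is now async and execute awaits it. The new ending }).await.unwrap().await; implies the closure handed to the spawning call (which is outside this hunk) returns a future: the first .await joins the spawned task, .unwrap() surfaces any panic, and the final .await drives the returned future. A minimal sketch of that shape, assuming a tokio spawn_blocking-style call site:

let output_future = tokio::task::spawn_blocking(move || async move {
    // heavy llama.cpp work interleaved with async channel sends would live here
})
.await      // join the spawned task
.unwrap();  // propagate any panic from it
output_future.await; // drive the inner future to completion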
7 changes: 4 additions & 3 deletions crates/llm-chain-llama/src/options.rs
@@ -43,7 +43,8 @@ impl LlamaInvocation {
pub(crate) fn new(opt: OptionsCascade, prompt: &Prompt) -> Option<LlamaInvocation> {
opt_extract!(opt, n_threads, NThreads);
opt_extract!(opt, n_tok_predict, MaxTokens);
opt_extract!(opt, token_bias, TokenBias);
// Skip TokenBias for now
//opt_extract!(opt, token_bias, TokenBias);
opt_extract!(opt, top_k, TopK);
opt_extract!(opt, top_p, TopP);
opt_extract!(opt, tfs_z, TfsZ);
@@ -59,8 +60,8 @@
opt_extract!(opt, penalize_nl, PenalizeNl);
opt_extract!(opt, stop_sequence, StopSequence);

let logit_bias = token_bias.as_i32_f32_hashmap()?;

let logit_bias = HashMap::<i32, f32>::new(); // token_bias.as_i32_f32_hashmap()?;
Some(LlamaInvocation {
n_threads: *n_threads as i32,
n_tok_predict: *n_tok_predict,
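
Because the TokenBias extraction is commented out, every invocation now runs with an empty logit-bias map. If the option is re-enabled later, the pre-change lines already show the intended shape (restored here only as a sketch, reusing the helper the old code called):

opt_extract!(opt, token_bias, TokenBias);
let logit_bias = token_bias.as_i32_f32_hashmap()?;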