Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Add loading code for ggjt
Browse files Browse the repository at this point in the history
Now it can load the model, but it's not working
  • Loading branch information
iacore committed Apr 6, 2023
1 parent af5415f commit db5bc8e
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 131 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,21 @@ impl Tensor {
/// # Safety
///
/// The data must not be mutated while being read from.
pub unsafe fn data(&self) -> *mut c_void {
pub unsafe fn data(&self) -> *const c_void {
self.with_alive_ctx(|| {
// SAFETY: The with_alive_call guarantees the context is alive
unsafe { *self.ptr.as_ptr() }.data
})
}

/// Set the tensor's data pointer (useful for mmap-ed data).
///
/// # Safety
///
/// `data_ptr` is stored as-is; nothing here validates it. The caller must
/// ensure it points to memory that is suitable for this tensor and that it
/// remains valid for as long as the tensor's data may be accessed.
pub unsafe fn set_data(&self, data_ptr: *mut c_void) {
    self.with_alive_ctx(|| {
        // SAFETY: `with_alive_ctx` guarantees the context is alive.
        //
        // The write must go through the raw pointer. Writing
        // `unsafe { *ptr }.data = ...` instead would evaluate the braces as a
        // value expression, copy the struct into a temporary and assign to
        // the copy's field, leaving the real tensor unchanged.
        unsafe { (*self.ptr.as_ptr()).data = data_ptr };
    })
}

/// Number of elements in this tensor.
pub fn nelements(&self) -> usize {
self.with_alive_ctx(|| {
Expand Down
1 change: 1 addition & 0 deletions llama-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ rand = { workspace = true }
serde = { version = "1.0.156", features = ["derive"] }
serde_bytes = "0.11"
bincode = "1.3.3"
memmap2 = "0.5.10"
25 changes: 17 additions & 8 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::{
time,
};

use memmap2::Mmap;
use thiserror::Error;

use partial_sort::PartialSort;
Expand Down Expand Up @@ -66,6 +67,8 @@ pub struct Model {

tensors: HashMap<String, ggml::Tensor>,

mmap: Option<Mmap>,

// Must be kept alive for the model
_context: ggml::Context,
}
Expand Down Expand Up @@ -502,7 +505,7 @@ pub enum LoadError {
/// The name of the tensor.
tensor_name: String,
/// The format type that was encountered.
ftype: u32,
ftype: i32,
/// The path that failed.
path: PathBuf,
},
Expand Down Expand Up @@ -585,12 +588,13 @@ impl Model {

let main_path = path.as_ref();

let mut reader =
BufReader::new(
File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
source: e,
path: main_path.to_owned(),
})?,
})?;
let mut reader =
BufReader::new(
&file,
);

// Verify magic
Expand Down Expand Up @@ -732,7 +736,7 @@ impl Model {
// Initialize the context
let context = ggml::Context::init(ctx_size);

let model = {
let mut model = {
let mut tensors = HashMap::new();

let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab);
Expand Down Expand Up @@ -796,15 +800,20 @@ impl Model {
layers,
tensors,
_context: context,
mmap: None,
}
};

match model_type {
ModelType::GGMF | ModelType::Unversioned => {
load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)?
let file_offset = reader.stream_position()?;
drop(reader);
load_weights_ggmf_or_unversioned(file_offset, main_path, load_progress_callback, &model)?
}
ModelType::GGJT => {
load_weights_ggjt(reader, main_path, load_progress_callback, &model)?
let mmap = unsafe { Mmap::map(&file)? };
load_weights_ggjt(&mut reader, &mmap, main_path, load_progress_callback, &model)?;
model.mmap = Some(mmap);
}
}

Expand Down
Loading

0 comments on commit db5bc8e

Please sign in to comment.