
Commit 2e62a18

Merge pull request #125 from iacore/llama-loader

Standalone loader

philpax authored Apr 22, 2023
2 parents bb0d7ea + b41fe14

Showing 18 changed files with 1,622 additions and 544 deletions.
412 changes: 309 additions & 103 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -2,6 +2,7 @@
members = [
"ggml-sys",
"ggml",
"ggml-loader",
"llama-rs",
"llama-cli",
"generate-ggml-bindings"
10 changes: 10 additions & 0 deletions ggml-loader/Cargo.toml
@@ -0,0 +1,10 @@
[package]
name = "ggml-loader"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ggml = { path = "../ggml" }
thiserror = "*"
236 changes: 236 additions & 0 deletions ggml-loader/src/lib.rs
@@ -0,0 +1,236 @@
//! Standalone model loader.
//!
//! Only the hyperparameters are llama-specific. Everything else can be reused for other LLMs.
#![allow(clippy::nonminimal_bool)]

pub mod util;

use std::ops::ControlFlow;
use util::*;

pub type ElementType = ggml::Type;

/// The container format of the file holding the model.
#[derive(Debug, PartialEq, Clone, Copy)]
#[allow(clippy::upper_case_acronyms)]
pub enum ContainerType {
/// Legacy format; the oldest ggml tensor file format.
GGML,
/// Also a legacy format; newer than GGML, older than GGJT.
GGMF,
/// The mmap-able format.
GGJT,
}

impl ContainerType {
pub fn support_mmap(&self) -> bool {
match self {
ContainerType::GGML => false,
ContainerType::GGMF => false,
ContainerType::GGJT => true,
}
}
}

#[derive(Debug, thiserror::Error)]
pub enum LoadError<T> {
#[error("invalid file magic number: {0}")]
InvalidMagic(u32),

#[error("invalid ggml format: version={0}")]
InvalidFormatVersion(u32),

#[error("{0}")]
Io(#[from] std::io::Error),

#[error("{0}")]
FailedCast(#[from] std::num::TryFromIntError),

/// Returned when a [`LoadHandler`] callback returns `ControlFlow::Break` to interrupt loading.
#[error("user requested interrupt: {0}")]
UserInterrupted(T),

#[error("unsupported tensor dtype/f16_: {0}")]
UnsupportedElementType(i32),

/// sanity check failed
#[error("invariant broken: {0}")]
InvariantBroken(String),
}

#[derive(Debug, Clone)]
pub struct TensorInfo {
pub name: Vec<u8>,
pub n_dims: usize,
pub dims: [usize; 2],
pub n_elements: usize,
pub ftype: ElementType,
/// Offset from the start of the file to the start of the tensor's data.
pub start_offset: u64,
}

/// The subset of the hyperparameters that the loader itself needs for later stages.
/// Returned by the [`LoadHandler::load_hyper_parameters`] callback.
#[derive(Debug, Clone)]
pub struct PartialHyperparameters {
pub n_vocab: usize,
}

/// What the loader should do with the data of the tensor it has just described.
pub enum TensorDataTreatment<'a> {
/// Copy the tensor's bytes into the provided buffer.
CopyInto(&'a mut [u8]),
/// Skip over the tensor's data.
SeekPast {
/// Number of bytes to skip; should be the tensor's size in bytes (`tensor.nbytes`).
n_bytes: usize,
},
}

/// Callbacks invoked by [`load_model_from_reader`] as it walks the file: container type,
/// hyperparameters, vocabulary tokens, then one call per tensor.
#[allow(unused_variables)]
pub trait LoadHandler<T, R: BufRead + Seek> {
fn got_container_type(&mut self, container_type: ContainerType) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn got_vocab_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow<T, PartialHyperparameters>;

/// Callback that decides what to do with the data of the tensor described by `info`.
///
/// # Returns
///
/// [`TensorDataTreatment::CopyInto`] with a buffer to copy the weights into, or
/// [`TensorDataTreatment::SeekPast`] to skip over them.
fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow<T, TensorDataTreatment>;
}

#[test]
fn can_be_vtable() {
use std::mem::MaybeUninit;
let _a: MaybeUninit<Box<dyn LoadHandler<(), std::fs::File>>> = MaybeUninit::uninit();
}

/// Load a model from `reader`, reporting data and progress back through `handler`.
///
/// Reads the container type (magic), the format version (if any), the hyperparameters,
/// the vocabulary, and finally the tensor data.
pub fn load_model_from_reader<T, R: BufRead + Seek>(
reader: &mut R,
handler: &mut impl LoadHandler<T, R>,
) -> Result<(), LoadError<T>> {
// Verify magic
let container_type: ContainerType = match read_u32(reader)? {
ggml::FILE_MAGIC_GGMF => ContainerType::GGMF,
ggml::FILE_MAGIC_GGJT => ContainerType::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML,
magic => return Err(LoadError::InvalidMagic(magic)),
};
controlflow_to_result(handler.got_container_type(container_type))?;

// Load format version
match container_type {
ContainerType::GGMF | ContainerType::GGJT => {
let _version: u32 = match read_u32(reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion(version)),
};
}
ContainerType::GGML => {}
}

// Load hyper params
let hparams = controlflow_to_result(handler.load_hyper_parameters(reader))?;
let n_vocab = hparams.n_vocab;

// Load vocabulary
for i in 0..n_vocab {
let len = read_u32(reader)?.try_into()?;
let token = read_bytes_with_len(reader, len)?;
let token_score = match container_type {
ContainerType::GGMF | ContainerType::GGJT => read_f32(reader)?,
ContainerType::GGML => {
// Legacy model, set empty score
0.
}
};
controlflow_to_result(handler.got_vocab_token(i, token, token_score))?;
}

// Load tensor data
match container_type {
ContainerType::GGMF | ContainerType::GGML => load_weights(reader, handler, false),
ContainerType::GGJT => load_weights(reader, handler, true),
}
}

/// # Params
///
/// `align`: if `true`, align to a 32-byte boundary before reading each tensor's weights
/// (as required by the mmap-able GGJT format).
pub fn load_weights<T, R: BufRead + Seek>(
reader: &mut R,
handler: &mut impl LoadHandler<T, R>,
align: bool,
) -> Result<(), LoadError<T>> {
while has_data_left(reader)? {
// load tensor header
let n_dims: usize = read_i32(reader)?.try_into()?;
let name_len = read_i32(reader)?;
let ftype = decode_element_type_res(read_i32(reader)?)?;

let mut n_elements: usize = 1;
let mut dims = [1usize, 1];
let ne_len = dims.len();
if !(n_dims <= ne_len) {
return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}")));
}
#[allow(clippy::needless_range_loop)]
for i in 0..n_dims {
let dim: usize = read_i32(reader)?.try_into()?;
dims[i] = dim;
n_elements *= dim;
}

// load tensor name
let name = read_bytes_with_len(reader, name_len.try_into()?)?;

// sanity check
match ftype {
ElementType::Q4_0 | ElementType::Q4_1 => {
if !(dims[0] % 64 == 0) {
return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0")));
}
}
_ => {}
}

// load tensor weights
let offset_curr = reader.stream_position()?;
let offset_aligned: u64 = if align {
// round up to the next multiple of 32
(offset_curr + 31) & !31
} else {
offset_curr
};

let tensor_info = TensorInfo {
name,
dims,
n_dims,
n_elements,
ftype,
start_offset: offset_aligned,
};

match controlflow_to_result(handler.tensor_buffer(tensor_info))? {
TensorDataTreatment::CopyInto(buf) => {
if align {
reader.seek(SeekFrom::Start(offset_aligned))?;
}
reader.read_exact(buf)?;
}
TensorDataTreatment::SeekPast { n_bytes } => {
// the handler declined to copy; seek past the tensor data
reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?;
}
}
}

Ok(())
}
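
For orientation, here is a minimal sketch (not part of this commit) of how the standalone loader can be driven: a handler that records tensor names and skips their data. The hyperparameter layout (seven little-endian i32 fields with n_vocab first), the model path, and the TensorLister type are illustrative assumptions, and quantized tensor sizes are not handled.

use std::fs::File;
use std::io::{BufRead, BufReader, Seek};
use std::ops::ControlFlow;

use ggml_loader::util::read_i32;
use ggml_loader::{
    load_model_from_reader, ElementType, LoadHandler, PartialHyperparameters,
    TensorDataTreatment, TensorInfo,
};

/// Collects tensor names while skipping all tensor data.
struct TensorLister {
    names: Vec<String>,
}

impl<R: BufRead + Seek> LoadHandler<String, R> for TensorLister {
    fn load_hyper_parameters(
        &mut self,
        reader: &mut R,
    ) -> ControlFlow<String, PartialHyperparameters> {
        // Assumed llama-style header: the first i32 field is the vocabulary size.
        let n_vocab = match read_i32(reader) {
            Ok(v) => v as usize,
            Err(e) => return ControlFlow::Break(e.to_string()),
        };
        // Skip the remaining six i32 fields of the assumed header.
        for _ in 0..6 {
            if let Err(e) = read_i32(reader) {
                return ControlFlow::Break(e.to_string());
            }
        }
        ControlFlow::Continue(PartialHyperparameters { n_vocab })
    }

    fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow<String, TensorDataTreatment> {
        self.names
            .push(String::from_utf8_lossy(&info.name).into_owned());
        // Only unquantized element sizes are handled in this sketch.
        let bytes_per_element = match info.ftype {
            ElementType::F32 => 4,
            ElementType::F16 => 2,
            _ => return ControlFlow::Break("quantized tensors not handled here".to_string()),
        };
        ControlFlow::Continue(TensorDataTreatment::SeekPast {
            n_bytes: info.n_elements * bytes_per_element,
        })
    }
}

fn main() {
    // Hypothetical model path; substitute a real GGML/GGMF/GGJT file.
    let file = File::open("models/ggml-model-f16.bin").expect("failed to open model");
    let mut handler = TensorLister { names: Vec::new() };
    load_model_from_reader(&mut BufReader::new(file), &mut handler).expect("load failed");
    println!("saw {} tensors", handler.names.len());
}
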
77 changes: 77 additions & 0 deletions ggml-loader/src/util.rs
@@ -0,0 +1,77 @@
pub use std::io::{BufRead, Seek, SeekFrom};
use std::ops::ControlFlow;

use crate::{ElementType, LoadError};

pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
let mut bytes = [0u8; N];
reader.read_exact(&mut bytes)?;
Ok(bytes)
}

pub fn read_i32(reader: &mut impl BufRead) -> Result<i32, std::io::Error> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub fn read_u32(reader: &mut impl BufRead) -> Result<u32, std::io::Error> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub fn read_f32(reader: &mut impl BufRead) -> Result<f32, std::io::Error> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub fn read_bytes_with_len(
reader: &mut impl BufRead,
len: usize,
) -> Result<Vec<u8>, std::io::Error> {
let mut bytes = vec![0u8; len];
reader.read_exact(&mut bytes)?;
Ok(bytes)
}

// NOTE: Implementation from #![feature(buf_read_has_data_left)]
pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error> {
reader.fill_buf().map(|b| !b.is_empty())
}

pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
match ftype {
0 => Some(ggml::Type::F32),
1 => Some(ggml::Type::F16),
2 => Some(ggml::Type::Q4_0),
3 => Some(ggml::Type::Q4_1),
_ => None,
}
}

pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
match element_type {
ggml::Type::F32 => Some(0),
ggml::Type::F16 => Some(1),
ggml::Type::Q4_0 => Some(2),
ggml::Type::Q4_1 => Some(3),
_ => None,
}
}

pub fn decode_element_type_res<T>(ftype: i32) -> Result<ElementType, LoadError<T>> {
match decode_element_type(ftype) {
Some(x) => Ok(x),
None => Err(LoadError::UnsupportedElementType(ftype)),
}
}

pub fn controlflow_to_result<A, B>(x: ControlFlow<A, B>) -> Result<B, LoadError<A>> {
match x {
ControlFlow::Continue(x) => Ok(x),
ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)),
}
}

pub fn result_to_controlflow<A, B, C: Into<A>>(x: Result<B, C>) -> ControlFlow<A, B> {
match x {
Ok(x) => ControlFlow::Continue(x),
Err(y) => ControlFlow::Break(y.into()),
}
}
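
As a small illustration (not part of this commit), the helpers above can be exercised on their own: the element-type codec round-trips its four known codes, and has_data_left detects the end of an in-memory reader. The values used are arbitrary sample data.

use std::io::Cursor;

use ggml_loader::util::{decode_element_type, encode_element_type, has_data_left, read_u32};

fn main() {
    // The four known ftype codes round-trip through the codec helpers.
    for code in 0..4 {
        let ty = decode_element_type(code).expect("codes 0..=3 are defined");
        assert_eq!(encode_element_type(ty), Some(code));
    }
    assert_eq!(decode_element_type(42), None);

    // has_data_left reports whether any bytes remain in the reader.
    // The sample value is the ASCII string "ggjt" packed into a u32.
    let mut cursor = Cursor::new(0x6767_6a74_u32.to_le_bytes());
    assert!(has_data_left(&mut cursor).unwrap());
    assert_eq!(read_u32(&mut cursor).unwrap(), 0x6767_6a74);
    assert!(!has_data_left(&mut cursor).unwrap());
}
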