Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
more code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
iacore committed Apr 6, 2023
1 parent ba9f91a commit 46fffc3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
36 changes: 24 additions & 12 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ struct Layer {
w3: ggml::Tensor,
}


/// Container format of a model file, chosen from the magic number at the
/// start of the file.
///
/// The loader branches on this value to decide whether a format version
/// field follows the magic, how vocabulary entry lengths are read, whether
/// per-token scores are present, and which weight-loading routine is used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum ModelVersion {
    /// Magic is `ggml::FILE_MAGIC_GGMF`; a format version follows the magic.
    GGMF,
    /// Magic is `ggml::FILE_MAGIC_GGJT`; a format version follows the magic.
    GGJT,
    /// Magic is `ggml::FILE_MAGIC_UNVERSIONED`; legacy file with no format
    /// version field.
    Unversioned,
}

/// The weights for the LLaMA model. All the mutable state is split into a
/// separate struct `InferenceSession`.
pub struct Model {
Expand All @@ -68,6 +77,8 @@ pub struct Model {
tensors: HashMap<String, ggml::Tensor>,

mmap: Option<Mmap>,

version: ModelVersion,

// Must be kept alive for the model
_context: ggml::Context,
Expand Down Expand Up @@ -595,10 +606,10 @@ impl Model {
let mut reader = BufReader::new(&file);

// Verify magic
let model_type: ModelType = match read_u32(&mut reader)? {
ggml::FILE_MAGIC_GGMF => ModelType::GGMF,
ggml::FILE_MAGIC_GGJT => ModelType::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned,
let model_type: ModelVersion = match read_u32(&mut reader)? {
ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF,
ggml::FILE_MAGIC_GGJT => ModelVersion::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ModelVersion::Unversioned,
_ => {
return Err(LoadError::InvalidMagic {
path: main_path.to_owned(),
Expand All @@ -608,13 +619,13 @@ impl Model {

// Load format version
match model_type {
ModelType::GGMF | ModelType::GGJT => {
ModelVersion::GGMF | ModelVersion::GGJT => {
let _version: u32 = match read_u32(&mut reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion { value: version }),
};
}
ModelType::Unversioned => {}
ModelVersion::Unversioned => {}
}

// =================
Expand Down Expand Up @@ -651,8 +662,8 @@ impl Model {
for i in 0..hparams.n_vocab {
let len = match model_type {
// NOTE: `read_i32` may be a typo here; the GGJT branch below uses `read_u32`
ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize,
ModelType::GGJT => read_u32(&mut reader)? as usize,
ModelVersion::GGMF | ModelVersion::Unversioned => read_i32(&mut reader)? as usize,
ModelVersion::GGJT => read_u32(&mut reader)? as usize,
};
let maybe_word = if len > 0 {
read_string(&mut reader, len)
Expand All @@ -673,12 +684,12 @@ impl Model {

// Token score, currently unused
match model_type {
ModelType::GGMF | ModelType::GGJT => {
ModelVersion::GGMF | ModelVersion::GGJT => {
if let Ok(score) = read_f32(&mut reader) {
id_to_token_score.push(score);
}
}
ModelType::Unversioned => {
ModelVersion::Unversioned => {
// Legacy model, set empty score
id_to_token_score.push(0.);
}
Expand Down Expand Up @@ -806,11 +817,12 @@ impl Model {
tensors,
_context: context,
mmap: None,
version: model_type,
}
};

match model_type {
ModelType::GGMF | ModelType::Unversioned => {
ModelVersion::GGMF | ModelVersion::Unversioned => {
let file_offset = reader.stream_position()?;
drop(reader);
load_weights_ggmf_or_unversioned(
Expand All @@ -820,7 +832,7 @@ impl Model {
&model,
)?
}
ModelType::GGJT => {
ModelVersion::GGJT => {
let mmap = unsafe { Mmap::map(&file)? };
load_weights_ggjt(
&mut reader,
Expand Down
7 changes: 0 additions & 7 deletions llama-rs/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,6 @@ fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error> {
reader.fill_buf().map(|b| !b.is_empty())
}

/// Container format of a model file, chosen from the magic number at the
/// start of the file.
///
/// Cheap to copy and compare: every variant is a unit variant, so the full
/// set of value-type traits is derived.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum ModelType {
    /// Magic is `ggml::FILE_MAGIC_GGMF`; a format version follows the magic.
    GGMF,
    /// Magic is `ggml::FILE_MAGIC_GGJT`; a format version follows the magic.
    GGJT,
    /// Magic is `ggml::FILE_MAGIC_UNVERSIONED`; no format version field.
    Unversioned,
}

pub(crate) fn load_weights_ggmf_or_unversioned(
file_offset: u64,
main_path: &Path,
Expand Down

0 comments on commit 46fffc3

Please sign in to comment.