diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 1ff7fa62..1b08420e 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -25,6 +25,7 @@ use thiserror::Error; use memmap2::Mmap; use llama_loader::{decode_element_type, ContainerType}; +use llama_loader::util::*; /// dummy struct #[cfg(not(feature = "mmap"))] @@ -609,48 +610,6 @@ impl Model { })?; let mut reader = BufReader::new(&file); - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_bytes_with_len( - reader: &mut impl BufRead, - len: usize, - ) -> Result, LoadError> { - let mut bytes = vec![0u8; len]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: len, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) - } - // Verify magic let model_type: ContainerType = match read_u32(&mut reader)? { ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, @@ -710,7 +669,7 @@ impl Model { for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; - let token = read_bytes_with_len(&mut reader, len)?; + let token = read_bytes_with_len(&mut reader, len.try_into()?)?; max_token_length = max_token_length.max(token.len()); id_to_token.push(token.clone()); token_to_id.insert(token, TokenId::try_from(i)?); diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index ad2b9ab0..b76be7fe 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -6,29 +6,7 @@ use std::{ use crate::ElementType; use crate::{util, LoadError, LoadProgress, Model}; use llama_loader::decode_element_type; - -pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) -} - -pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) -} +use llama_loader::util::*; /// Helper function. Reads a string from the buffer and returns it. pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { @@ -43,11 +21,6 @@ pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result Result { - reader.fill_buf().map(|b| !b.is_empty()) -} - pub(crate) fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path,