diff --git a/llama-loader/src/lib.rs b/llama-loader/src/lib.rs index 8cf4ed8d..3a0ab8f4 100644 --- a/llama-loader/src/lib.rs +++ b/llama-loader/src/lib.rs @@ -109,7 +109,7 @@ pub trait LoadHandler { } /// # Returns - /// + /// /// `None` to skip copying /// `Some(buf)` to provide a buffer for copying weights into fn get_tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow> { @@ -242,10 +242,9 @@ fn load_weights_ggjt( n_dims, n_elements, ftype, - start_offset: offset_aligned + start_offset: offset_aligned, }; - let type_size = ggml::type_size(ftype); if let Some(buf) = retchk(handler.get_tensor_buffer(tensor_info))? { reader.seek(SeekFrom::Start(offset_aligned))?; @@ -258,7 +257,9 @@ fn load_weights_ggjt( reader.read_exact(buf)?; } else { // skip if no buffer is given - reader.seek(SeekFrom::Start(offset_aligned + (type_size * n_elements) as u64))?; + reader.seek(SeekFrom::Start( + offset_aligned + (type_size * n_elements) as u64, + ))?; } } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 1ff7fa62..df5b2da0 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -24,6 +24,7 @@ use thiserror::Error; #[cfg(feature = "mmap")] use memmap2::Mmap; +use llama_loader::util::*; use llama_loader::{decode_element_type, ContainerType}; /// dummy struct @@ -609,48 +610,6 @@ impl Model { })?; let mut reader = BufReader::new(&file); - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_bytes_with_len( - reader: &mut impl BufRead, - len: usize, - ) -> Result, LoadError> { - let mut bytes = vec![0u8; len]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: len, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) - } - // Verify magic let model_type: ContainerType = match read_u32(&mut reader)? { ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, @@ -710,7 +669,7 @@ impl Model { for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; - let token = read_bytes_with_len(&mut reader, len)?; + let token = read_bytes_with_len(&mut reader, len.try_into()?)?; max_token_length = max_token_length.max(token.len()); id_to_token.push(token.clone()); token_to_id.insert(token, TokenId::try_from(i)?); diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index ad2b9ab0..b76be7fe 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -6,29 +6,7 @@ use std::{ use crate::ElementType; use crate::{util, LoadError, LoadProgress, Model}; use llama_loader::decode_element_type; - -pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) -} - -pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) -} - -pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) -} +use llama_loader::util::*; /// Helper function. Reads a string from the buffer and returns it. pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { @@ -43,11 +21,6 @@ pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result Result { - reader.fill_buf().map(|b| !b.is_empty()) -} - pub(crate) fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path,