@@ -148,9 +148,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_names_from_parts.update(model_part.keys())

                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                     yield name, data

         # only verify tensor name presence; it doesn't matter if they are not in the right files
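With this change, the lazy safetensors path calls `model_part.get_slice(name)` instead of `model_part.get_tensor(name)`, so no tensor data is read from disk until the tensor is actually consumed. A minimal sketch of the safetensors slice API the new branch relies on (the file and tensor names are illustrative, not from this commit):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a tiny file so the example is self-contained.
save_file({"w": torch.ones(4, 2)}, "demo.safetensors")

with safe_open("demo.safetensors", framework="pt", device="cpu") as f:
    sl = f.get_slice("w")                  # cheap: no tensor data read yet
    print(sl.get_dtype(), sl.get_shape())  # "F32" [4, 2]
    data = sl[:]                           # materializes the full tensor
    print(data.shape)                      # torch.Size([4, 2])
```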
@@ -3435,6 +3442,27 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.float32: np.float32,
     }

+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
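The keys of `_dtype_str_map` are exactly the strings a safetensors slice reports from `get_dtype()`; the unsigned 64/32/16-bit entries stay commented out because torch support for those dtypes is missing or incomplete (see the pytorch issue referenced in the TODO). A quick, illustrative check of the correspondence (file and tensor names are made up):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"b": torch.zeros(3, dtype=torch.bfloat16)}, "dtypes.safetensors")

with safe_open("dtypes.safetensors", framework="pt", device="cpu") as f:
    key = f.get_slice("b").get_dtype()  # "BF16"
    print(key)  # the _dtype_str_map key that maps to torch.bfloat16
```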
@@ -3448,6 +3476,13 @@ def numpy(self) -> gguf.LazyNumpyTensor:
     def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")

+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape = st_slice.get_shape()
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[0][:])
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types  # unused
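`from_safetensors_slice` keeps the slice handle in `args` and defers the actual read to `func`, which takes the full range (`s[0][:]`) only when the lazy tensor is finally evaluated; until then the `meta`-device tensor carries just the shape and dtype. A stand-alone sketch of that pattern, with a hypothetical `LazyHandle` in place of `gguf.LazyBase` (names here are illustrative, not the library's API):

```python
from __future__ import annotations

from typing import Callable

import torch


class LazyHandle:
    """Hypothetical stand-in for gguf.LazyBase: a meta tensor describes
    the result; func(args) produces it on first access."""

    def __init__(self, meta: torch.Tensor, args: tuple, func: Callable[[tuple], torch.Tensor]):
        self.meta = meta    # dtype/shape only, no storage ("meta" device)
        self._args = args
        self._func = func
        self._value: torch.Tensor | None = None

    def realize(self) -> torch.Tensor:
        if self._value is None:
            self._value = self._func(self._args)  # e.g. reads the slice: s[0][:]
        return self._value


# Usage with any object supporting "[:]"; a plain tensor stands in for a slice.
src = torch.arange(6).reshape(2, 3)
lazy = LazyHandle(
    meta=torch.empty(src.shape, dtype=src.dtype, device="meta"),
    args=(src,),
    func=lambda s: s[0][:],
)
print(lazy.meta.shape)       # torch.Size([2, 3]) -- known without touching data
print(lazy.realize().sum())  # tensor(15) -- materialized on demand
```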