@@ -65,7 +65,8 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
+                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -96,7 +97,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         ftype_lw: str = ftype_up.lower()
         # allow templating the file name with the output ftype, useful with the "auto" ftype
         self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
-        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
+        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)

     @classmethod
     def __init_subclass__(cls):
@@ -332,6 +334,8 @@ def write(self):
         self.gguf_writer.close()

     def write_vocab(self):
+        if len(self.gguf_writer.tensors) != 1:
+            raise ValueError('Splitting the vocabulary is not supported')
         self.gguf_writer.write_header_to_file(self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
@@ -2974,10 +2978,44 @@ def parse_args() -> argparse.Namespace:
         "--verbose", action="store_true",
         help="increase output verbosity",
     )
+    parser.add_argument(
+        "--split-max-tensors", type=int, default=0,
+        help="max tensors in each split",
+    )
+    parser.add_argument(
+        "--split-max-size", type=str, default="0",
+        help="max size per split N(M|G)",
+    )
+    parser.add_argument(
+        "--dry-run", action="store_true",
+        help="only print out a split plan and exit, without writing any new files",
+    )
+    parser.add_argument(
+        "--no-tensor-first-split", action="store_true",
+        help="do not add tensors to the first split (disabled by default)"
+    )

     return parser.parse_args()


+def split_str_to_n_bytes(split_str: str) -> int:
+    if split_str.endswith("K"):
+        n = int(split_str[:-1]) * 1000
+    elif split_str.endswith("M"):
+        n = int(split_str[:-1]) * 1000 * 1000
+    elif split_str.endswith("G"):
+        n = int(split_str[:-1]) * 1000 * 1000 * 1000
+    elif split_str.isnumeric():
+        n = int(split_str)
+    else:
+        raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
+
+    if n < 0:
+        raise ValueError(f"Invalid split size: {split_str}, must be positive")
+
+    return n
+
+
 def main() -> None:
     args = parse_args()

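A minimal sketch of how the size strings accepted by --split-max-size are interpreted by the split_str_to_n_bytes helper added above. The units are decimal (1K = 1000 bytes), and the expected values below are worked out from the arithmetic in the patch rather than taken from a test run:

# Expected behaviour of split_str_to_n_bytes for a few inputs (illustrative only).
assert split_str_to_n_bytes("0") == 0                 # the default "0" means no size cap
assert split_str_to_n_bytes("250K") == 250_000        # decimal kilo, not KiB
assert split_str_to_n_bytes("300M") == 300_000_000
assert split_str_to_n_bytes("2G") == 2_000_000_000
# Inputs such as "2.5G", "300MiB", or "-1" raise ValueError.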
@@ -3010,6 +3048,10 @@ def main() -> None:
         "auto": gguf.LlamaFileType.GUESSED,
     }

+    if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
+        logger.error("Error: Cannot use temp file when splitting")
+        sys.exit(1)
+
     if args.outfile is not None:
         fname_out = args.outfile
     else:
@@ -3027,7 +3069,10 @@ def main() -> None:
             logger.error(f"Model {hparams['architectures'][0]} is not supported")
             sys.exit(1)

-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
+                                     args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
+                                     split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
+                                     small_first_shard=args.no_tensor_first_split)

         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
@@ -3038,13 +3083,13 @@ def main() -> None:
         model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)

         if args.vocab_only:
-            logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
+            logger.info("Exporting model vocab...")
             model_instance.write_vocab()
+            logger.info("Model vocab successfully exported.")
         else:
-            logger.info(f"Exporting model to '{model_instance.fname_out}'")
+            logger.info("Exporting model...")
             model_instance.write()
-
-        logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
+            logger.info("Model successfully exported.")


 if __name__ == '__main__':
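As a usage illustration only, not part of the patch: the new flags could be driven as below. The script name, the positional model path, and the python3 interpreter are assumptions for the sketch; the actual splitting and shard naming are handled inside gguf.GGUFWriter, which this script only configures.

# Hypothetical invocation sketch for the new splitting options; paths are placeholders.
import subprocess

script = "convert-hf-to-gguf.py"   # assumed script name
model_dir = "path/to/hf-model"     # placeholder model directory

# Preview the split plan without writing any files (per the --dry-run help text).
subprocess.run(["python3", script, model_dir, "--split-max-size", "2G", "--dry-run"], check=True)

# Convert for real, capping each shard at 2G (parsed to 2_000_000_000 bytes).
subprocess.run(["python3", script, model_dir, "--split-max-size", "2G"], check=True)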