55import os
66from io import StringIO
77from pathlib import Path
8+ from typing import Any , Dict , Iterator , Tuple
89
910import numpy as np
1011from tvm import tir
@@ -83,7 +84,7 @@ def _check_param(name: str, param: NDArray):
8384 nonlocal named_params
8485 if name not in named_params :
8586 raise ValueError (f"Parameter not found in model: { name } " )
86- if name in param_dict :
87+ if name in param_names :
8788 raise ValueError (f"Duplication: Parameter { name } already computed" )
8889
8990 # Check shape (possibly dynamic)
@@ -112,20 +113,43 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var
112113 del named_params [name ]
113114
114115 # load and quantize
115- param_dict = {}
116+ param_names = set ()
116117 total_bytes = 0.0
117- with Target .from_device (args .device ), tqdm .redirect ():
118- loader = LOADER [args .source_format ](
119- path = args .source ,
120- extern_param_map = args .model .source [args .source_format ](model_config , args .quantization ),
121- quantize_param_map = quantize_map ,
122- )
123- for name , param in loader .load (device = args .device , preshard_funcs = preshard_funcs ):
124- _check_param (name , param )
125- param = param .copyto (cpu_device ())
126- param_dict [name ] = param
127- total_bytes += math .prod (param .shape ) * np .dtype (param .dtype ).itemsize
128- total_params = loader .stats .total_param_num
118+ total_params : int
119+
120+ def _param_generator () -> Iterator [Tuple [str , NDArray ]]:
121+ nonlocal total_params , total_bytes
122+ with Target .from_device (args .device ), tqdm .redirect ():
123+ loader = LOADER [args .source_format ](
124+ path = args .source ,
125+ extern_param_map = args .model .source [args .source_format ](
126+ model_config , args .quantization
127+ ),
128+ quantize_param_map = quantize_map ,
129+ )
130+ for name , param in loader .load (device = args .device , preshard_funcs = preshard_funcs ):
131+ _check_param (name , param )
132+ param_names .add (name )
133+ param = param .copyto (cpu_device ())
134+ total_bytes += math .prod (param .shape ) * np .dtype (param .dtype ).itemsize
135+ yield name , param
136+ total_params = loader .stats .total_param_num
137+
138+ def _metadata_callback () -> Dict [str , Any ]:
139+ return {
140+ "ParamSize" : len (param_names ),
141+ "ParamBytes" : total_bytes ,
142+ "BitsPerParam" : total_bytes * 8.0 / total_params ,
143+ }
144+
145+ # dump to output directory
146+ tvmjs .dump_ndarray_cache (
147+ _param_generator (),
148+ str (args .output ),
149+ meta_data = _metadata_callback ,
150+ encode_format = "f32-to-bf16" ,
151+ show_progress = False ,
152+ )
129153 if named_params :
130154 raise ValueError (f"Parameter not found in source: { ', ' .join (named_params .keys ())} " )
131155 # Log necessary statistics
@@ -140,17 +164,6 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var
140164 green ("Bits per parameter" ),
141165 total_bytes * 8.0 / total_params ,
142166 )
143- # dump to output directory
144- tvmjs .dump_ndarray_cache (
145- param_dict ,
146- str (args .output ),
147- meta_data = {
148- "ParamSize" : len (param_dict ),
149- "ParamBytes" : total_bytes ,
150- "BitsPerParam" : total_bytes * 8.0 / total_params ,
151- },
152- encode_format = "f32-to-bf16" ,
153- )
154167 logger .info ("Saved to directory: %s" , bold (str (args .output )))
155168
156169
0 commit comments