1717"""Namespace to store utilities for building web runtime."""
1818import hashlib
1919import json
20+ import math
2021import os
2122import shutil
2223
2324# pylint: disable=unused-import
2425import sys
25- from typing import Mapping , Union
26+ from types import GeneratorType
27+ from typing import Iterator , Mapping , Tuple , Union
2628
2729import numpy as np
2830
@@ -149,37 +151,48 @@ def pending_nbytes(self):
149151
150152
151153def dump_ndarray_cache (
152- params : Mapping [str , Union [np .ndarray , tvm .runtime .NDArray ]],
154+ params : Union [
155+ Mapping [str , Union [np .ndarray , tvm .runtime .NDArray ]],
156+ Iterator [Tuple [str , Union [np .ndarray , tvm .runtime .NDArray ]]],
157+ ],
153158 cache_dir : str ,
154159 encode_format = "f32-to-bf16" ,
155160 meta_data = None ,
156161 shard_cap_mb = 32 ,
162+ show_progress : bool = True ,
157163):
158164 """Dump parameters to NDArray cache.
159165
160166 Parameters
161167 ----------
162- params: Mapping[str, tvm.runtime.NDArray],
163- The parameter dictionary
168+ params: Union[
169+ Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
170+ Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
171+ ]
172+ The parameter dictionary or generator
164173
165174 cache_dir: str
166175 The path to the cache
167176
168177 encode_format: {"f32-to-bf16", "raw"}
169178 Encoding format.
170179
171- meta_data: json-compatible-struct
172- Extra meta_data to be stored in the cache json file.
180+ meta_data: json-compatible-struct or Callable[[], Any]
181+ Extra meta_data to be stored in the cache json file,
182+ or a callable that returns the metadata.
173183
174184 shard_cap_mb: int
175185 Maximum number of MB to be kept per shard
186+
187+ show_progress: bool
188+ Whether to show the dump progress.
176189 """
177190 if encode_format not in ("raw" , "f32-to-bf16" ):
178191 raise ValueError (f"Invalie encode_format { encode_format } " )
179192
180- meta_data = {} if meta_data is None else meta_data
181193 records = []
182- total = len (params )
194+ from_generator = isinstance (params , GeneratorType )
195+ total_bytes = 0
183196 counter = 0
184197 max_out_length = 0
185198
@@ -193,14 +206,16 @@ def dump_ndarray_cache(
193206
194207 shard_manager = NDArrayCacheShardingManager (cache_dir , "params_shard" , shard_cap_nbytes )
195208
196- for k , origin_v in params .items ():
209+ param_generator = params .items () if not from_generator else params
210+ for k , origin_v in param_generator :
197211 shape = list (origin_v .shape )
198212 v = origin_v
199213 if not isinstance (v , np .ndarray ):
200214 v = v .numpy ()
201215
202216 # prefer to preserve original dtype, especially if the format was bfloat16
203217 dtype = str (origin_v .dtype ) if isinstance (origin_v , tvm .nd .NDArray ) else str (v .dtype )
218+ total_bytes += math .prod (v .shape ) * np .dtype (v .dtype ).itemsize
204219
205220 # convert fp32 to bf16
206221 if encode_format == "f32-to-bf16" and dtype == "float32" :
@@ -212,12 +227,14 @@ def dump_ndarray_cache(
212227 shard_manager .append (data , name = k , shape = shape , dtype = dtype , encode_format = encode_format )
213228
214229 counter += 1
215- last_cmd = "[%04d/%04d] saving %s" % (counter , total , k )
216- flush = "\r " + (" " * max_out_length ) + "\r "
217- max_out_length = max (len (last_cmd ), max_out_length )
218- sys .stdout .write (flush + last_cmd )
230+ if show_progress :
231+ last_cmd = "[%04d] saving %s" % (counter , k )
232+ flush = "\r " + (" " * max_out_length ) + "\r "
233+ max_out_length = max (len (last_cmd ), max_out_length )
234+ sys .stdout .write (flush + last_cmd )
219235
220236 records = shard_manager .finish ()
237+ meta_data = {} if meta_data is None else meta_data if not callable (meta_data ) else meta_data ()
221238
222239 nd_cache_json = os .path .join (cache_dir , "ndarray-cache.json" )
223240
0 commit comments