Skip to content

Commit 16a83c8

Browse files
committed
[Contrib] Support NDArray cache taking generator
This PR enhances the `dump_ndarray_cache` function to take generator as input. Previously it can only take a dictionary. Sometimes, it is possible that the total ndarray size cannot fit the main CPU memory, in which case we may turn to using generators so we can free some NDArray memory on the fly. And this PR supports the NDArray cache dumping with generators.
1 parent 40dd376 commit 16a83c8

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

python/tvm/contrib/tvmjs.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
"""Namespace to store utilities for building web runtime."""
1818
import hashlib
1919
import json
20+
import math
2021
import os
2122
import shutil
2223

2324
# pylint: disable=unused-import
2425
import sys
25-
from typing import Mapping, Union
26+
from types import GeneratorType
27+
from typing import Iterator, Mapping, Tuple, Union
2628

2729
import numpy as np
2830

@@ -149,37 +151,48 @@ def pending_nbytes(self):
149151

150152

151153
def dump_ndarray_cache(
152-
params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
154+
params: Union[
155+
Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
156+
Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
157+
],
153158
cache_dir: str,
154159
encode_format="f32-to-bf16",
155160
meta_data=None,
156161
shard_cap_mb=32,
162+
show_progress: bool = True,
157163
):
158164
"""Dump parameters to NDArray cache.
159165
160166
Parameters
161167
----------
162-
params: Mapping[str, tvm.runtime.NDArray],
163-
The parameter dictionary
168+
params: Union[
169+
Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
170+
Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
171+
]
172+
The parameter dictionary or generator
164173
165174
cache_dir: str
166175
The path to the cache
167176
168177
encode_format: {"f32-to-bf16", "raw"}
169178
Encoding format.
170179
171-
meta_data: json-compatible-struct
172-
Extra meta_data to be stored in the cache json file.
180+
meta_data: json-compatible-struct or Callable[[], Any]
181+
Extra meta_data to be stored in the cache json file,
182+
or a callable that returns the metadata.
173183
174184
shard_cap_mb: int
175185
Maxinum number of MB to be kept per shard
186+
187+
show_progress: bool
188+
A boolean indicating if to show the dump progress.
176189
"""
177190
if encode_format not in ("raw", "f32-to-bf16"):
178191
raise ValueError(f"Invalie encode_format {encode_format}")
179192

180-
meta_data = {} if meta_data is None else meta_data
181193
records = []
182-
total = len(params)
194+
from_generator = isinstance(params, GeneratorType)
195+
total_bytes = 0
183196
counter = 0
184197
max_out_length = 0
185198

@@ -193,14 +206,16 @@ def dump_ndarray_cache(
193206

194207
shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)
195208

196-
for k, origin_v in params.items():
209+
param_generator = params.items() if not from_generator else params
210+
for k, origin_v in param_generator:
197211
shape = list(origin_v.shape)
198212
v = origin_v
199213
if not isinstance(v, np.ndarray):
200214
v = v.numpy()
201215

202216
# prefer to preserve original dtype, especially if the format was bfloat16
203217
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)
218+
total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize
204219

205220
# convert fp32 to bf16
206221
if encode_format == "f32-to-bf16" and dtype == "float32":
@@ -212,12 +227,14 @@ def dump_ndarray_cache(
212227
shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)
213228

214229
counter += 1
215-
last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
216-
flush = "\r" + (" " * max_out_length) + "\r"
217-
max_out_length = max(len(last_cmd), max_out_length)
218-
sys.stdout.write(flush + last_cmd)
230+
if show_progress:
231+
last_cmd = "[%04d] saving %s" % (counter, k)
232+
flush = "\r" + (" " * max_out_length) + "\r"
233+
max_out_length = max(len(last_cmd), max_out_length)
234+
sys.stdout.write(flush + last_cmd)
219235

220236
records = shard_manager.finish()
237+
meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data()
221238

222239
nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")
223240

0 commit comments

Comments
 (0)