Skip to content

Commit b598f28 — [Contrib] Implement NDArray cache update (#17029). 1 parent: 7359313.

File tree: 2 files changed (+94 additions, −7 deletions).

python/tvm/contrib/tvmjs.py

Lines changed: 69 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
# pylint: disable=unused-import
2525
import sys
2626
from types import GeneratorType
27-
from typing import Iterator, Mapping, Tuple, Union
27+
from typing import Any, Iterator, Mapping, Optional, Set, Tuple, Union
2828

2929
import numpy as np
3030

@@ -73,16 +73,31 @@ def _calculate_md5(filename):
7373
class NDArrayCacheShardingManager:
7474
"""Internal helper to shard ndarrays."""
7575

76-
def __init__(self, cache_dir: str, prefix: str, shard_cap_nbytes: int):
76+
def __init__(
77+
self,
78+
cache_dir: str,
79+
prefix: str,
80+
shard_cap_nbytes: int,
81+
initial_shard_records: Optional[Mapping[str, Any]] = None,
82+
):
7783
self.cache_dir = cache_dir
7884
self.prefix = prefix
7985
self.curr_records = []
8086
self.curr_data = bytearray()
8187
self.shard_records = []
8288
self.shard_cap_nbytes = shard_cap_nbytes
8389
self.counter = 0
90+
self.name_to_record: Mapping[str, Tuple[int, Mapping[str, Any]]] = {}
91+
self.updated_shards: Set[int] = set()
8492

85-
def append(self, data, name, shape, dtype, encode_format):
93+
if initial_shard_records is not None:
94+
self.shard_records = initial_shard_records
95+
self.counter = len(initial_shard_records)
96+
for idx, shard in enumerate(initial_shard_records):
97+
for rec in shard["records"]:
98+
self.name_to_record[rec["name"]] = (idx, rec)
99+
100+
def append_or_update(self, data, name, shape, dtype, encode_format, allow_update: bool = False):
86101
"""Commit a record to the manager.
87102
88103
Parameters
@@ -101,6 +116,9 @@ def append(self, data, name, shape, dtype, encode_format):
101116
102117
encode_format:
103118
The encode format of the entry
119+
120+
allow_update: bool
121+
If the record already exists, update the record. Otherwise, raise an error.
104122
"""
105123
rec = {
106124
"name": name,
@@ -109,6 +127,13 @@ def append(self, data, name, shape, dtype, encode_format):
109127
"format": encode_format,
110128
"nbytes": len(data),
111129
}
130+
if name in self.name_to_record:
131+
if not allow_update:
132+
raise ValueError(f"Duplicate name {name} found in the cache.")
133+
self.update_single_record(rec, data)
134+
return
135+
136+
self.name_to_record[name] = (self.counter, rec)
112137

113138
if self.pending_nbytes + len(data) >= self.shard_cap_nbytes:
114139
if len(data) * 2 >= self.shard_cap_nbytes:
@@ -121,6 +146,20 @@ def append(self, data, name, shape, dtype, encode_format):
121146
self.curr_records.append(rec)
122147
self.curr_data += data
123148

149+
def update_single_record(self, rec, data):
150+
"""Update a single record in a shard file."""
151+
name = rec["name"]
152+
idx, old_rec = self.name_to_record[name]
153+
if old_rec["nbytes"] != rec["nbytes"]:
154+
raise ValueError(f"Cannot update record {name}, size mismatch.")
155+
data_path = self.shard_records[idx]["dataPath"]
156+
full_path = os.path.join(self.cache_dir, data_path)
157+
with open(full_path, "r+b") as outfile:
158+
outfile.seek(old_rec["byteOffset"])
159+
outfile.write(data)
160+
self.name_to_record[name] = (idx, rec)
161+
self.updated_shards.add(idx)
162+
124163
def commit(self):
125164
"""Commit a record"""
126165
if self.pending_nbytes != 0:
@@ -131,6 +170,9 @@ def commit(self):
131170
def finish(self):
132171
"""Finish building and return shard records."""
133172
self.commit()
173+
for idx in self.updated_shards:
174+
full_path = os.path.join(self.cache_dir, self.shard_records[idx]["dataPath"])
175+
self.shard_records[idx]["md5sum"] = _calculate_md5(full_path)
134176
return self.shard_records
135177

136178
def _commit_internal(self, data, records):
@@ -165,6 +207,7 @@ def dump_ndarray_cache(
165207
meta_data=None,
166208
shard_cap_mb=32,
167209
show_progress: bool = True,
210+
update_if_exists: bool = False,
168211
):
169212
"""Dump parameters to NDArray cache.
170213
@@ -191,6 +234,10 @@ def dump_ndarray_cache(
191234
192235
show_progress: bool
193236
A boolean indicating if to show the dump progress.
237+
238+
update_if_exists: bool
239+
If the cache already exists, update the cache. When set to False, it will overwrite the
240+
existing files.
194241
"""
195242
if encode_format not in ("raw", "f32-to-bf16"):
196243
raise ValueError(f"Invalie encode_format {encode_format}")
@@ -209,7 +256,17 @@ def dump_ndarray_cache(
209256
print("Start storing to cache %s" % cache_dir)
210257
shard_cap_nbytes = shard_cap_mb * (1 << 20)
211258

212-
shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)
259+
nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")
260+
if update_if_exists and os.path.exists(nd_cache_json):
261+
with open(nd_cache_json, "r") as infile:
262+
old_data = json.load(infile)
263+
if meta_data is None:
264+
meta_data = old_data["metadata"]
265+
records = old_data["records"]
266+
267+
shard_manager = NDArrayCacheShardingManager(
268+
cache_dir, "params_shard", shard_cap_nbytes, initial_shard_records=records
269+
)
213270

214271
param_generator = params.items() if not from_generator else params
215272
for k, origin_v in param_generator:
@@ -229,7 +286,14 @@ def dump_ndarray_cache(
229286
else:
230287
data = v.tobytes()
231288

232-
shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)
289+
shard_manager.append_or_update(
290+
data,
291+
name=k,
292+
shape=shape,
293+
dtype=dtype,
294+
encode_format=encode_format,
295+
allow_update=update_if_exists,
296+
)
233297

234298
counter += 1
235299
if show_progress:
@@ -241,8 +305,6 @@ def dump_ndarray_cache(
241305
records = shard_manager.finish()
242306
meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data()
243307

244-
nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")
245-
246308
with open(nd_cache_json, "w") as outfile:
247309
json.dump({"metadata": meta_data, "records": records}, outfile, indent=4)
248310
print(

tests/python/relax/test_runtime_builtin.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,31 @@ def test_ndarray_cache():
188188
np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6)
189189

190190

191+
def test_ndarray_cache_update():
192+
fload = tvm.get_global_func("vm.builtin.ndarray_cache.load")
193+
fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache")
194+
195+
param_dict = {
196+
"x_0": np.array([1, 2, 3], dtype="int32"),
197+
"x_1": np.random.uniform(size=[10, 20]).astype("float32"),
198+
}
199+
200+
temp = utils.tempdir()
201+
tvmjs.dump_ndarray_cache(param_dict, temp.path, encode_format="f32-to-bf16")
202+
param_dict["x_1"] = np.random.uniform(size=[10, 20]).astype("float32")
203+
param_dict["x_2"] = np.random.uniform(size=[10]).astype("float32")
204+
tvmjs.dump_ndarray_cache(
205+
param_dict, temp.path, encode_format="f32-to-bf16", update_if_exists=True
206+
)
207+
fload(str(temp.path), tvm.cpu().device_type, 0)
208+
res = fget_params("x", -1)
209+
for i, v in enumerate(res):
210+
v_np = param_dict[f"x_{i}"]
211+
if v_np.dtype == "float32":
212+
v_np = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(v_np))
213+
np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6)
214+
215+
191216
def test_attention_kv_cache_window_override():
192217
fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
193218
foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override")

Comments (0)