# pylint: disable=unused-import
import sys
from types import GeneratorType
from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union

import numpy as np
@@ -73,16 +73,31 @@ def _calculate_md5(filename):
7373class NDArrayCacheShardingManager :
7474 """Internal helper to shard ndarrays."""
7575
76- def __init__ (self , cache_dir : str , prefix : str , shard_cap_nbytes : int ):
76+ def __init__ (
77+ self ,
78+ cache_dir : str ,
79+ prefix : str ,
80+ shard_cap_nbytes : int ,
81+ initial_shard_records : Optional [Mapping [str , Any ]] = None ,
82+ ):
7783 self .cache_dir = cache_dir
7884 self .prefix = prefix
7985 self .curr_records = []
8086 self .curr_data = bytearray ()
8187 self .shard_records = []
8288 self .shard_cap_nbytes = shard_cap_nbytes
8389 self .counter = 0
90+ self .name_to_record : Mapping [str , Tuple [int , Mapping [str , Any ]]] = {}
91+ self .updated_shards : Set [int ] = set ()
8492
85- def append (self , data , name , shape , dtype , encode_format ):
93+ if initial_shard_records is not None :
94+ self .shard_records = initial_shard_records
95+ self .counter = len (initial_shard_records )
96+ for idx , shard in enumerate (initial_shard_records ):
97+ for rec in shard ["records" ]:
98+ self .name_to_record [rec ["name" ]] = (idx , rec )
99+
100+ def append_or_update (self , data , name , shape , dtype , encode_format , allow_update : bool = False ):
86101 """Commit a record to the manager.
87102
88103 Parameters
@@ -101,6 +116,9 @@ def append(self, data, name, shape, dtype, encode_format):
101116
102117 encode_format:
103118 The encode format of the entry
119+
120+ allow_update: bool
121+ If the record already exists, update the record. Otherwise, raise an error.
104122 """
105123 rec = {
106124 "name" : name ,
@@ -109,6 +127,13 @@ def append(self, data, name, shape, dtype, encode_format):
109127 "format" : encode_format ,
110128 "nbytes" : len (data ),
111129 }
130+ if name in self .name_to_record :
131+ if not allow_update :
132+ raise ValueError (f"Duplicate name { name } found in the cache." )
133+ self .update_single_record (rec , data )
134+ return
135+
136+ self .name_to_record [name ] = (self .counter , rec )
112137
113138 if self .pending_nbytes + len (data ) >= self .shard_cap_nbytes :
114139 if len (data ) * 2 >= self .shard_cap_nbytes :
@@ -121,6 +146,20 @@ def append(self, data, name, shape, dtype, encode_format):
121146 self .curr_records .append (rec )
122147 self .curr_data += data
123148
149+ def update_single_record (self , rec , data ):
150+ """Update a single record in a shard file."""
151+ name = rec ["name" ]
152+ idx , old_rec = self .name_to_record [name ]
153+ if old_rec ["nbytes" ] != rec ["nbytes" ]:
154+ raise ValueError (f"Cannot update record { name } , size mismatch." )
155+ data_path = self .shard_records [idx ]["dataPath" ]
156+ full_path = os .path .join (self .cache_dir , data_path )
157+ with open (full_path , "r+b" ) as outfile :
158+ outfile .seek (old_rec ["byteOffset" ])
159+ outfile .write (data )
160+ self .name_to_record [name ] = (idx , rec )
161+ self .updated_shards .add (idx )
162+
124163 def commit (self ):
125164 """Commit a record"""
126165 if self .pending_nbytes != 0 :
@@ -131,6 +170,9 @@ def commit(self):
131170 def finish (self ):
132171 """Finish building and return shard records."""
133172 self .commit ()
173+ for idx in self .updated_shards :
174+ full_path = os .path .join (self .cache_dir , self .shard_records [idx ]["dataPath" ])
175+ self .shard_records [idx ]["md5sum" ] = _calculate_md5 (full_path )
134176 return self .shard_records
135177
136178 def _commit_internal (self , data , records ):
@@ -165,6 +207,7 @@ def dump_ndarray_cache(
165207 meta_data = None ,
166208 shard_cap_mb = 32 ,
167209 show_progress : bool = True ,
210+ update_if_exists : bool = False ,
168211):
169212 """Dump parameters to NDArray cache.
170213
@@ -191,6 +234,10 @@ def dump_ndarray_cache(
191234
192235 show_progress: bool
193236 A boolean indicating if to show the dump progress.
237+
238+ update_if_exists: bool
239+ If the cache already exists, update the cache. When set to False, it will overwrite the
240+ existing files.
194241 """
195242 if encode_format not in ("raw" , "f32-to-bf16" ):
196243 raise ValueError (f"Invalie encode_format { encode_format } " )
@@ -209,7 +256,17 @@ def dump_ndarray_cache(
209256 print ("Start storing to cache %s" % cache_dir )
210257 shard_cap_nbytes = shard_cap_mb * (1 << 20 )
211258
212- shard_manager = NDArrayCacheShardingManager (cache_dir , "params_shard" , shard_cap_nbytes )
259+ nd_cache_json = os .path .join (cache_dir , "ndarray-cache.json" )
260+ if update_if_exists and os .path .exists (nd_cache_json ):
261+ with open (nd_cache_json , "r" ) as infile :
262+ old_data = json .load (infile )
263+ if meta_data is None :
264+ meta_data = old_data ["metadata" ]
265+ records = old_data ["records" ]
266+
267+ shard_manager = NDArrayCacheShardingManager (
268+ cache_dir , "params_shard" , shard_cap_nbytes , initial_shard_records = records
269+ )
213270
214271 param_generator = params .items () if not from_generator else params
215272 for k , origin_v in param_generator :
@@ -229,7 +286,14 @@ def dump_ndarray_cache(
229286 else :
230287 data = v .tobytes ()
231288
232- shard_manager .append (data , name = k , shape = shape , dtype = dtype , encode_format = encode_format )
289+ shard_manager .append_or_update (
290+ data ,
291+ name = k ,
292+ shape = shape ,
293+ dtype = dtype ,
294+ encode_format = encode_format ,
295+ allow_update = update_if_exists ,
296+ )
233297
234298 counter += 1
235299 if show_progress :
@@ -241,8 +305,6 @@ def dump_ndarray_cache(
241305 records = shard_manager .finish ()
242306 meta_data = {} if meta_data is None else meta_data if not callable (meta_data ) else meta_data ()
243307
244- nd_cache_json = os .path .join (cache_dir , "ndarray-cache.json" )
245-
246308 with open (nd_cache_json , "w" ) as outfile :
247309 json .dump ({"metadata" : meta_data , "records" : records }, outfile , indent = 4 )
248310 print (
0 commit comments