16
16
import cloudpickle
17
17
import torch as th
18
18
19
+ import stable_baselines3
19
20
from stable_baselines3 .common .type_aliases import TensorDict
20
21
from stable_baselines3 .common .utils import get_device
21
22
@@ -284,21 +285,20 @@ def save_to_zip_file(
284
285
save_path : Union [str , pathlib .Path , io .BufferedIOBase ],
285
286
data : Dict [str , Any ] = None ,
286
287
params : Dict [str , Any ] = None ,
287
- tensors : Dict [str , Any ] = None ,
288
+ pytorch_variables : Dict [str , Any ] = None ,
288
289
verbose = 0 ,
289
290
) -> None :
290
291
"""
291
- Save a model to a zip archive.
292
+ Save model data to a zip archive.
292
293
293
294
:param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
294
295
if save_path is a str or pathlib.Path ensures that the path actually exists.
295
- :param data: Class parameters being stored.
296
+ :param data: Class parameters being stored (non-PyTorch variables)
296
297
:param params: Model parameters being stored expected to contain an entry for every
297
298
state_dict with its name and the state_dict.
298
- :param tensors: Extra tensor variables expected to contain name and value of tensors
299
+ :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
299
300
:param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
300
301
"""
301
-
302
302
save_path = open_path (save_path , "w" , verbose = 0 , suffix = "zip" )
303
303
# data/params can be None, so do not
304
304
# try to serialize them blindly
@@ -310,13 +310,15 @@ def save_to_zip_file(
310
310
# Do not try to save "None" elements
311
311
if data is not None :
312
312
archive .writestr ("data" , serialized_data )
313
- if tensors is not None :
314
- with archive .open ("tensors.pth" , mode = "w" ) as tensors_file :
315
- th .save (tensors , tensors_file )
313
+ if pytorch_variables is not None :
314
+ with archive .open ("pytorch_variables.pth" , mode = "w" ) as pytorch_variables_file :
315
+ th .save (pytorch_variables , pytorch_variables_file )
316
316
if params is not None :
317
317
for file_name , dict_ in params .items ():
318
318
with archive .open (file_name + ".pth" , mode = "w" ) as param_file :
319
319
th .save (dict_ , param_file )
320
+ # Save metadata: library version when file was saved
321
+ archive .writestr ("_stable_baselines3_version" , stable_baselines3 .__version__ )
320
322
321
323
322
324
def save_to_pkl (path : Union [str , pathlib .Path , io .BufferedIOBase ], obj , verbose = 0 ) -> None :
@@ -362,8 +364,8 @@ def load_from_zip_file(
362
364
:param load_data: Whether we should load and return data
363
365
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
364
366
:param device: (Union[th.device, str]) Device on which the code should run.
365
- :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
366
- and dict of extra tensors
367
+ :return: (dict),(dict),(dict) Class parameters, model state_dicts (aka "params", dict of state_dict)
368
+ and dict of pytorch variables
367
369
"""
368
370
load_path = open_path (load_path , "r" , verbose = verbose , suffix = "zip" )
369
371
@@ -378,44 +380,38 @@ def load_from_zip_file(
378
380
# zip archive, assume they were stored
379
381
# as None (_save_to_file_zip allows this).
380
382
data = None
381
- tensors = None
383
+ pytorch_variables = None
382
384
params = {}
383
385
384
386
if "data" in namelist and load_data :
385
- # Load class parameters and convert to string
387
+ # Load class parameters that are stored
388
+ # with either JSON or pickle (not PyTorch variables).
386
389
json_data = archive .read ("data" ).decode ()
387
390
data = json_to_data (json_data )
388
391
389
- if "tensors.pth" in namelist and load_data :
390
- # Load extra tensors
391
- with archive .open ("tensors.pth" , mode = "r" ) as tensor_file :
392
- # File has to be seekable, but opt_param_file is not, so load in BytesIO first
392
+ # Check for all .pth files and load them using th.load.
393
+ # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
394
+ # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
395
+ pth_files = [file_name for file_name in namelist if os .path .splitext (file_name )[1 ] == ".pth" ]
396
+ for file_path in pth_files :
397
+ with archive .open (file_path , mode = "r" ) as param_file :
398
+ # File has to be seekable, but param_file is not, so load in BytesIO first
393
399
# fixed in python >= 3.7
394
400
file_content = io .BytesIO ()
395
- file_content .write (tensor_file .read ())
401
+ file_content .write (param_file .read ())
396
402
# go to start of file
397
403
file_content .seek (0 )
398
- # load the parameters with the right ``map_location``
399
- tensors = th .load (file_content , map_location = device )
400
-
401
- # check for all other .pth files
402
- other_files = [
403
- file_name for file_name in namelist if os .path .splitext (file_name )[1 ] == ".pth" and file_name != "tensors.pth"
404
- ]
405
- # if there are any other files which end with .pth and aren't "params.pth"
406
- # assume that they each are optimizer parameters
407
- if len (other_files ) > 0 :
408
- for file_path in other_files :
409
- with archive .open (file_path , mode = "r" ) as opt_param_file :
410
- # File has to be seekable, but opt_param_file is not, so load in BytesIO first
411
- # fixed in python >= 3.7
412
- file_content = io .BytesIO ()
413
- file_content .write (opt_param_file .read ())
414
- # go to start of file
415
- file_content .seek (0 )
416
- # load the parameters with the right ``map_location``
417
- params [os .path .splitext (file_path )[0 ]] = th .load (file_content , map_location = device )
404
+ # Load the parameters with the right ``map_location``.
405
+ # Remove ".pth" ending with splitext
406
+ th_object = th .load (file_content , map_location = device )
407
+ if file_path == "pytorch_variables.pth" :
408
+ # PyTorch variables (not state_dicts)
409
+ pytorch_variables = th_object
410
+ else :
411
+ # State dicts. Store into params dictionary
412
+ # with same name as in .zip file (without .pth)
413
+ params [os .path .splitext (file_path )[0 ]] = th_object
418
414
except zipfile .BadZipFile :
419
415
# load_path wasn't a zip file
420
416
raise ValueError (f"Error: the file { load_path } wasn't a zip-file" )
421
- return data , params , tensors
417
+ return data , params , pytorch_variables
0 commit comments