
Commit 9855486

Miffyli and araffin authored
Get/set parameters and review of saving and loading (#138)
* Update comments and docstrings
* Rename get_torch_variables to private and update docs
* Clarify documentation on data, params and tensors
* Make excluded_save_params private and update docs
* Update get_torch_variable_names to get_torch_save_params for description
* Simplify saving code and update docs on params vs tensors
* Rename saved item tensors to pytorch_variables for clarity
* Reformat
* Fix a typo
* Add get/set_parameters, update tests accordingly
* Use f-strings for formatting
* Fix load docstring
* Reorganize functions in BaseClass
* Update changelog
* Add library version to the stored models
* Actually run isort this time
* Fix flake8 complaints and also fix testing code
* Fix isort
* ...and black
* Fix set_random_seed

Co-authored-by: Antonin RAFFIN <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>
1 parent 00595b0 commit 9855486

8 files changed: +420 −290 lines

docs/misc/changelog.rst (+10)

@@ -16,13 +16,15 @@ New Features:
 - Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
 - Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)
 - Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
+- Added ``get_parameters`` and ``set_parameters`` for accessing/setting parameters of the agent
 - Added actor/critic loss logging for TD3. (@mloo3)

 Bug Fixes:
 ^^^^^^^^^^
 - Fixed a bug where the environment was reset twice when using ``evaluate_policy``
 - Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
 - Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
+- Fixed a bug where the random seed was not properly set on cuda when passing the GPU index

 Deprecations:
 ^^^^^^^^^^^^^
@@ -33,6 +35,14 @@ Others:
 - Fix type annotation of ``make_vec_env`` (@ManifoldFR)
 - Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used
 - Fixed typos in SAC and TD3
+- Renamed ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
+  ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
+- Reorganized functions for clarity in ``BaseClass`` (save/load functions close to each other, private
+  functions at top)
+- Clarified docstrings on what is saved and loaded to/from files
+- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
+- Simplified ``save_to_zip_file`` function by removing duplicate code
+- Store library version along with the saved models

 Documentation:
 ^^^^^^^^^^^^^^

stable_baselines3/common/base_class.py (+285 −187)

Large diffs are not rendered by default.
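Since the ``base_class.py`` diff is not rendered here, a minimal usage sketch of the new public API, pieced together from the updated tests further below (the env id and setup are illustrative, not part of this diff):

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")

# get_parameters returns a dict that maps object names
# (e.g. "policy", "policy.optimizer") to their state-dicts,
# matching the layout exercised by the updated tests.
params = model.get_parameters()

# set_parameters loads such a dict back into the agent.
# With exact_match=True, a missing or unknown object name
# raises a ValueError (see the new test cases below).
model.set_parameters(params, exact_match=True)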

stable_baselines3/common/on_policy_algorithm.py (+1 −4)

@@ -240,10 +240,7 @@ def learn(

         return self

-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "policy.optimizer"]

         return state_dicts, []

stable_baselines3/common/save_util.py (+34 −38)

@@ -16,6 +16,7 @@
 import cloudpickle
 import torch as th

+import stable_baselines3
 from stable_baselines3.common.type_aliases import TensorDict
 from stable_baselines3.common.utils import get_device

@@ -284,21 +285,20 @@ def save_to_zip_file(
     save_path: Union[str, pathlib.Path, io.BufferedIOBase],
     data: Dict[str, Any] = None,
     params: Dict[str, Any] = None,
-    tensors: Dict[str, Any] = None,
+    pytorch_variables: Dict[str, Any] = None,
     verbose=0,
 ) -> None:
     """
-    Save a model to a zip archive.
+    Save model data to a zip archive.

     :param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model.
         if save_path is a str or pathlib.Path ensures that the path actually exists.
-    :param data: Class parameters being stored.
+    :param data: Class parameters being stored (non-PyTorch variables)
     :param params: Model parameters being stored expected to contain an entry for every
         state_dict with its name and the state_dict.
-    :param tensors: Extra tensor variables expected to contain name and value of tensors
+    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
     :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information
     """
-
     save_path = open_path(save_path, "w", verbose=0, suffix="zip")
     # data/params can be None, so do not
     # try to serialize them blindly
@@ -310,13 +310,15 @@ def save_to_zip_file(
         # Do not try to save "None" elements
         if data is not None:
             archive.writestr("data", serialized_data)
-        if tensors is not None:
-            with archive.open("tensors.pth", mode="w") as tensors_file:
-                th.save(tensors, tensors_file)
+        if pytorch_variables is not None:
+            with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
+                th.save(pytorch_variables, pytorch_variables_file)
         if params is not None:
             for file_name, dict_ in params.items():
                 with archive.open(file_name + ".pth", mode="w") as param_file:
                     th.save(dict_, param_file)
+        # Save metadata: library version when file was saved
+        archive.writestr("_stable_baselines3_version", stable_baselines3.__version__)


 def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0) -> None:
@@ -362,8 +364,8 @@ def load_from_zip_file(
     :param load_data: Whether we should load and return data
         (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
     :param device: (Union[th.device, str]) Device on which the code should run.
-    :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
-        and dict of extra tensors
+    :return: (dict),(dict),(dict) Class parameters, model state_dicts (aka "params", dict of state_dict)
+        and dict of pytorch variables
     """
     load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")

@@ -378,44 +380,38 @@ def load_from_zip_file(
             # zip archive, assume they were stored
             # as None (_save_to_file_zip allows this).
             data = None
-            tensors = None
+            pytorch_variables = None
             params = {}

             if "data" in namelist and load_data:
-                # Load class parameters and convert to string
+                # Load class parameters that are stored
+                # with either JSON or pickle (not PyTorch variables).
                 json_data = archive.read("data").decode()
                 data = json_to_data(json_data)

-            if "tensors.pth" in namelist and load_data:
-                # Load extra tensors
-                with archive.open("tensors.pth", mode="r") as tensor_file:
-                    # File has to be seekable, but opt_param_file is not, so load in BytesIO first
+            # Check for all .pth files and load them using th.load.
+            # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
+            # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
+            pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
+            for file_path in pth_files:
+                with archive.open(file_path, mode="r") as param_file:
+                    # File has to be seekable, but param_file is not, so load in BytesIO first
                     # fixed in python >= 3.7
                     file_content = io.BytesIO()
-                    file_content.write(tensor_file.read())
+                    file_content.write(param_file.read())
                     # go to start of file
                     file_content.seek(0)
-                    # load the parameters with the right ``map_location``
-                    tensors = th.load(file_content, map_location=device)
-
-            # check for all other .pth files
-            other_files = [
-                file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"
-            ]
-            # if there are any other files which end with .pth and aren't "params.pth"
-            # assume that they each are optimizer parameters
-            if len(other_files) > 0:
-                for file_path in other_files:
-                    with archive.open(file_path, mode="r") as opt_param_file:
-                        # File has to be seekable, but opt_param_file is not, so load in BytesIO first
-                        # fixed in python >= 3.7
-                        file_content = io.BytesIO()
-                        file_content.write(opt_param_file.read())
-                        # go to start of file
-                        file_content.seek(0)
-                        # load the parameters with the right ``map_location``
-                        params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
+                    # Load the parameters with the right ``map_location``.
+                    # Remove ".pth" ending with splitext
+                    th_object = th.load(file_content, map_location=device)
+                    if file_path == "pytorch_variables.pth":
+                        # PyTorch variables (not state_dicts)
+                        pytorch_variables = th_object
+                    else:
+                        # State dicts. Store into params dictionary
+                        # with same name as in .zip file (without .pth)
+                        params[os.path.splitext(file_path)[0]] = th_object
     except zipfile.BadZipFile:
         # load_path wasn't a zip file
         raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
-    return data, params, tensors
+    return data, params, pytorch_variables
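For reference, a quick way to inspect the new archive layout after this change (the env id and save name are illustrative; the entry names come from the diff above):

import zipfile

from stable_baselines3 import SAC

model = SAC("MlpPolicy", "Pendulum-v0")
model.save("sac_pendulum")  # SB3 appends the .zip suffix when missing

with zipfile.ZipFile("sac_pendulum.zip") as archive:
    # Expected entries per this commit: "data" (class parameters), one
    # "<name>.pth" per state_dict (e.g. "policy.pth"), "pytorch_variables.pth"
    # and the new "_stable_baselines3_version" metadata entry.
    print(archive.namelist())
    print(archive.read("_stable_baselines3_version").decode())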

stable_baselines3/dqn/dqn.py (+3 −13)

@@ -231,20 +231,10 @@ def learn(
             reset_num_timesteps=reset_num_timesteps,
         )

-    def excluded_save_params(self) -> List[str]:
-        """
-        Returns the names of the parameters that should be excluded by default
-        when saving the model.
-
-        :return: (List[str]) List of parameters that should be excluded from save
-        """
-        # Exclude aliases
-        return super(DQN, self).excluded_save_params() + ["q_net", "q_net_target"]
+    def _excluded_save_params(self) -> List[str]:
+        return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]

-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "policy.optimizer"]

         return state_dicts, []

stable_baselines3/sac/sac.py (+7 −17)

@@ -293,24 +293,14 @@ def learn(
             reset_num_timesteps=reset_num_timesteps,
         )

-    def excluded_save_params(self) -> List[str]:
-        """
-        Returns the names of the parameters that should be excluded by default
-        when saving the model.
-
-        :return: (List[str]) List of parameters that should be excluded from save
-        """
-        # Exclude aliases
-        return super(SAC, self).excluded_save_params() + ["actor", "critic", "critic_target"]
-
-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _excluded_save_params(self) -> List[str]:
+        return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
+
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
-        saved_tensors = ["log_ent_coef"]
+        saved_pytorch_variables = ["log_ent_coef"]
         if self.ent_coef_optimizer is not None:
             state_dicts.append("ent_coef_optimizer")
         else:
-            saved_tensors.append("ent_coef_tensor")
-        return state_dicts, saved_tensors
+            saved_pytorch_variables.append("ent_coef_tensor")
+        return state_dicts, saved_pytorch_variables
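The SAC override above shows the convention behind the renamed hook: the first returned list names objects whose state_dict() goes into its own .pth entry, the second names attributes stored directly in "pytorch_variables.pth". A sketch of the same hook in a hypothetical subclass (MySAC and the extra attribute name are illustrative only):

from typing import List, Tuple

from stable_baselines3 import SAC


class MySAC(SAC):
    # Hypothetical subclass; shown only to illustrate the two-list convention.
    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts, saved_pytorch_variables = super()._get_torch_save_params()
        # "my_extra_tensor" would be an attribute holding a raw tensor;
        # it would then be saved into "pytorch_variables.pth".
        return state_dicts, saved_pytorch_variables + ["my_extra_tensor"]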

stable_baselines3/td3/td3.py (+4 −14)

@@ -205,19 +205,9 @@ def learn(
             reset_num_timesteps=reset_num_timesteps,
         )

-    def excluded_save_params(self) -> List[str]:
-        """
-        Returns the names of the parameters that should be excluded by default
-        when saving the model.
-
-        :return: (List[str]) List of parameters that should be excluded from save
-        """
-        # Exclude aliases
-        return super(TD3, self).excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
-
-    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
-        """
-        cf base class
-        """
+    def _excluded_save_params(self) -> List[str]:
+        return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
+
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
         return state_dicts, []
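All three algorithms exclude module aliases (such as "actor" and "critic") from the pickled "data" entry, since those are rebuilt from the policy state-dicts at load time. A small sketch of the same hook in a hypothetical subclass (MyTD3 and the attribute name are illustrative):

from typing import List

from stable_baselines3 import TD3


class MyTD3(TD3):
    # Hypothetical subclass; illustrates excluding an extra attribute
    # from the saved "data" entry (it would be rebuilt at load time).
    def _excluded_save_params(self) -> List[str]:
        return super()._excluded_save_params() + ["my_runtime_cache"]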

tests/test_save_load.py (+76 −17)

@@ -2,6 +2,7 @@
 import os
 import pathlib
 import warnings
+from collections import OrderedDict
 from copy import deepcopy

 import gym
@@ -33,7 +34,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
 def test_save_load(tmp_path, model_class):
     """
     Test if 'save' and 'load' saves and loads model correctly
-    and if 'load_parameters' and 'get_policy_parameters' work correctly
+    and if 'get_parameters' and 'set_parameters' work correctly.

     ''warning does not test function of optimizer parameter load

@@ -49,19 +50,73 @@ def test_save_load(tmp_path, model_class):
     env.reset()
     observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

-    # Get dictionary of current parameters
-    params = deepcopy(model.policy.state_dict())
+    # Get parameters of different objects
+    # deepcopy to avoid referencing tensors we are about to modify
+    original_params = deepcopy(model.get_parameters())

-    # Modify all parameters to be random values
-    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+    # Test different error cases of set_parameters.
+    # Test that invalid object names throw errors
+    invalid_object_params = deepcopy(original_params)
+    invalid_object_params["I_should_not_be_a_valid_object"] = "and_I_am_an_invalid_tensor"
+    with pytest.raises(ValueError):
+        model.set_parameters(invalid_object_params, exact_match=True)
+    with pytest.raises(ValueError):
+        model.set_parameters(invalid_object_params, exact_match=False)

-    # Update model parameters with the new random values
-    model.policy.load_state_dict(random_params)
+    # Test that exact_match catches when something was missed.
+    missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
+    with pytest.raises(ValueError):
+        model.set_parameters(missing_object_params, exact_match=True)
+
+    # Test that exact_match catches when something inside state-dict
+    # is missing but we have exact_match.
+    missing_state_dict_tensor_params = {}
+    for object_name in original_params:
+        object_params = {}
+        missing_state_dict_tensor_params[object_name] = object_params
+        # Skip last item in state-dict
+        for k, v in list(original_params[object_name].items())[:-1]:
+            object_params[k] = v
+    with pytest.raises(RuntimeError):
+        # PyTorch load_state_dict throws RuntimeError if strict but
+        # invalid state-dict.
+        model.set_parameters(missing_state_dict_tensor_params, exact_match=True)
+
+    # Test that parameters do indeed change.
+    random_params = {}
+    for object_name, params in original_params.items():
+        # Do not randomize optimizer parameters (custom layout)
+        if "optim" in object_name:
+            random_params[object_name] = params
+        else:
+            # Again, skip the last item in state-dict
+            random_params[object_name] = OrderedDict(
+                (param_name, th.rand_like(param)) for param_name, param in list(params.items())[:-1]
+            )

-    new_params = model.policy.state_dict()
-    # Check that all params are different now
-    for k in params:
-        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
+    # Update model parameters with the new random values
+    model.set_parameters(random_params, exact_match=False)
+
+    new_params = model.get_parameters()
+    # Check that all params except the final item in each state-dict are different.
+    for object_name in original_params:
+        # Skip optimizers (no valid comparison with just th.allclose)
+        if "optim" in object_name:
+            continue
+        # state-dicts use ordered dictionaries, so key order
+        # is guaranteed.
+        last_key = list(original_params[object_name].keys())[-1]
+        for k in original_params[object_name]:
+            if k == last_key:
+                # Should be same as before
+                assert th.allclose(
+                    original_params[object_name][k], new_params[object_name][k]
+                ), "Parameter changed despite not being included in the loaded parameters."
+            else:
+                # Should be different
+                assert not th.allclose(
+                    original_params[object_name][k], new_params[object_name][k]
+                ), "Parameters did not change as expected."

     params = new_params

@@ -81,14 +136,18 @@ def test_save_load(tmp_path, model_class):
     assert model.policy.device.type == get_device(device).type

     # check if params are still the same after load
-    new_params = model.policy.state_dict()
+    new_params = model.get_parameters()

     # Check that all params are the same as before save load procedure now
-    for key in params:
-        assert new_params[key].device.type == get_device(device).type
-        assert th.allclose(
-            params[key].to("cpu"), new_params[key].to("cpu")
-        ), "Model parameters not the same after save and load."
+    for object_name in new_params:
+        # Skip optimizers (no valid comparison with just th.allclose)
+        if "optim" in object_name:
+            continue
+        for key in params[object_name]:
+            assert new_params[object_name][key].device.type == get_device(device).type
+            assert th.allclose(
+                params[object_name][key].to("cpu"), new_params[object_name][key].to("cpu")
+            ), "Model parameters not the same after save and load."

     # check if model still selects the same actions
     new_selected_actions, _ = model.predict(observations, deterministic=True)
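Outside the test suite, the save/load round trip the test exercises looks roughly like this sketch (paths and env id illustrative; the ``device`` keyword on ``load`` is the one added in this release per the changelog):

import numpy as np

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")
model.save("test_save")  # writes test_save.zip

loaded = PPO.load("test_save", device="cpu")

# Same weights should yield the same deterministic action.
obs = np.zeros((1, 4), dtype=np.float32)  # CartPole observation shape
action, _ = model.predict(obs, deterministic=True)
new_action, _ = loaded.predict(obs, deterministic=True)
assert np.allclose(action, new_action)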
