Commit fdecd51

Add save/load weights for policies and refactor action distributions

1 parent: b782f3a

11 files changed: +319 -211 lines

docs/misc/changelog.rst (+4)

@@ -10,10 +10,12 @@ Pre-Release 0.4.0a0 (WIP)
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Removed CEMRL
+- Models saved with previous versions cannot be loaded (because of the pre-processing)

 New Features:
 ^^^^^^^^^^^^^
 - Add support for Discrete observation spaces
+- Add saving/loading for policy weights, so the policy can be used without the model

 Bug Fixes:
 ^^^^^^^^^^
@@ -26,6 +28,8 @@ Others:
 ^^^^^^^
 - Refactor handling of observation and action spaces
 - Refactored features extraction to have proper preprocessing
+- Refactored action distributions
+

 Documentation:
 ^^^^^^^^^^^^^^

tests/test_distributions.py (+9 -5)

@@ -38,7 +38,8 @@ def test_squashed_gaussian(model_class):
     gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
     dist = SquashedDiagGaussianDistribution(N_ACTIONS)
     _, log_std = dist.proba_distribution_net(N_FEATURES)
-    actions, _ = dist.proba_distribution(gaussian_mean, log_std)
+    dist = dist.proba_distribution(gaussian_mean, log_std)
+    actions = dist.get_action()
     assert th.max(th.abs(actions)) <= 1.0

 def test_sde_distribution():
@@ -51,7 +52,8 @@ def test_sde_distribution():
     _, log_std = dist.proba_distribution_net(N_FEATURES)
     dist.sample_weights(log_std, batch_size=N_SAMPLES)

-    actions, _ = dist.proba_distribution(deterministic_actions, log_std, state)
+    dist = dist.proba_distribution(deterministic_actions, log_std, state)
+    actions = dist.get_action()

     assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=1e-3)
     assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=1e-3)
@@ -71,11 +73,12 @@ def test_entropy(dist):
     _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))

     if isinstance(dist, DiagGaussianDistribution):
-        actions, dist = dist.proba_distribution(deterministic_actions, log_std)
+        dist = dist.proba_distribution(deterministic_actions, log_std)
     else:
         dist.sample_weights(log_std, batch_size=N_SAMPLES)
-        actions, dist = dist.proba_distribution(deterministic_actions, log_std, state)
+        dist = dist.proba_distribution(deterministic_actions, log_std, state)

+    actions = dist.get_action()
     entropy = dist.entropy()
     log_prob = dist.log_prob(actions)
     assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
@@ -88,8 +91,9 @@ def test_categorical():
     set_random_seed(1)
     state = th.rand(N_SAMPLES, N_FEATURES)
     action_logits = th.rand(N_SAMPLES, N_ACTIONS)
-    actions, dist = dist.proba_distribution(action_logits)
+    dist = dist.proba_distribution(action_logits)

+    actions = dist.get_action()
     entropy = dist.entropy()
     log_prob = dist.log_prob(actions)
     assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4)
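
The pattern these tests now exercise: proba_distribution no longer returns an (actions, distribution) tuple, it returns the distribution object itself, and sampling becomes an explicit get_action() call on that object. A minimal sketch of the new calling convention, mirroring test_squashed_gaussian above (the import path for the distribution class is an assumption, not shown in this diff; the constants reuse the test's own names):

import torch as th

# Assumed location of the distribution classes (not visible in this diff)
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution

N_SAMPLES, N_ACTIONS, N_FEATURES = 256, 2, 3

dist = SquashedDiagGaussianDistribution(N_ACTIONS)
_, log_std = dist.proba_distribution_net(N_FEATURES)
gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)

# Old API: actions, _ = dist.proba_distribution(gaussian_mean, log_std)
# New API: build the distribution first, then sample from it explicitly
dist = dist.proba_distribution(gaussian_mean, log_std)
actions = dist.get_action()
log_prob = dist.log_prob(actions)  # log-probability of the sampled actions

Keeping the distribution object around is what lets test_entropy call entropy() and log_prob() on the same object after sampling.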

tests/test_identity.py (+1 -1)

@@ -20,7 +20,7 @@ def test_continuous(model_class):
     env = IdentityEnvBox(eps=0.5)

     n_steps = {
-        A2C: 3000,
+        A2C: 3500,
         PPO: 3000,
         SAC: 700,
         TD3: 500
tests/test_save_load.py (+59 -1)

@@ -16,7 +16,7 @@
     SAC,
 ]

-
+#
 @pytest.mark.parametrize("model_class", MODEL_LIST)
 def test_save_load(model_class):
     """
@@ -160,3 +160,61 @@ def test_save_load_replay_buffer(model_class):

     # clear file from os
     os.remove(replay_path)
+
+
+@pytest.mark.parametrize("model_class", MODEL_LIST)
+def test_save_load_policy(model_class):
+    """
+    Test saving and loading policy only.
+
+    :param model_class: (BaseRLModel) A RL model
+    """
+    env = DummyVecEnv([lambda: IdentityEnvBox(10)])
+
+    # create model
+    model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
+    model.learn(total_timesteps=500, eval_freq=250)
+
+    env.reset()
+    observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
+    observations = observations.reshape(10, -1)
+
+    policy = model.policy
+
+    # Get dictionary of current parameters
+    params = deepcopy(policy.state_dict())
+
+    # Modify all parameters to be random values
+    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+
+    # Update model parameters with the new random values
+    policy.load_state_dict(random_params)
+
+    new_params = policy.state_dict()
+    # Check that all params are different now
+    for k in params:
+        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
+
+    params = new_params
+
+    # get selected actions
+    selected_actions, _ = policy.predict(observations, deterministic=True)
+
+    # Save and load policy
+    policy.save("./logs/policy_weights.pkl")
+    # del policy
+    policy.load("./logs/policy_weights.pkl")
+
+    # check if params are still the same after load
+    new_params = policy.state_dict()
+
+    # Check that all params are the same as before save load procedure now
+    for key in params:
+        assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
+
+    # check if model still selects the same actions
+    new_selected_actions, _ = policy.predict(observations, deterministic=True)
+    assert np.allclose(selected_actions, new_selected_actions, 1e-4)
+
+    # clear file from os
+    os.remove("./logs/policy_weights.pkl")

torchy_baselines/common/base_class.py (+4 -122)

@@ -158,27 +158,6 @@ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
         assert eval_env.num_envs == 1
         return eval_env

-    def scale_action(self, action: np.ndarray) -> np.ndarray:
-        """
-        Rescale the action from [low, high] to [-1, 1]
-        (no need for symmetric action space)
-
-        :param action: (np.ndarray) Action to scale
-        :return: (np.ndarray) Scaled action
-        """
-        low, high = self.action_space.low, self.action_space.high
-        return 2.0 * ((action - low) / (high - low)) - 1.0
-
-    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
-        """
-        Rescale the action from [-1, 1] to [low, high]
-        (no need for symmetric action space)
-
-        :param scaled_action: Action to un-scale
-        """
-        low, high = self.action_space.low, self.action_space.high
-        return low + (0.5 * (scaled_action + 1.0) * (high - low))
-
     def _setup_lr_schedule(self) -> None:
         """Transform to callable if needed."""
         self.lr_schedule = get_schedule_fn(self.learning_rate)
@@ -318,57 +297,6 @@ def learn(self, total_timesteps: int,
         """
         raise NotImplementedError()

-    @staticmethod
-    def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
-        """
-        For every observation type, detects and validates the shape,
-        then returns whether or not the observation is vectorized.
-
-        :param observation: (np.ndarray) the input observation to validate
-        :param observation_space: (gym.spaces) the observation space
-        :return: (bool) whether the given observation is vectorized or not
-        """
-        if isinstance(observation_space, gym.spaces.Box):
-            if observation.shape == observation_space.shape:
-                return False
-            elif observation.shape[1:] == observation_space.shape:
-                return True
-            else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
-                                 "Box environment, please use {} ".format(observation_space.shape) +
-                                 "or (n_env, {}) for the observation shape."
-                                 .format(", ".join(map(str, observation_space.shape))))
-        elif isinstance(observation_space, gym.spaces.Discrete):
-            if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
-                return False
-            elif len(observation.shape) == 1:
-                return True
-            else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
-                                 "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
-        # TODO: add support for MultiDiscrete and MultiBinary observation spaces
-        # elif isinstance(observation_space, gym.spaces.MultiDiscrete):
-        #     if observation.shape == (len(observation_space.nvec),):
-        #         return False
-        #     elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
-        #         return True
-        #     else:
-        #         raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
-        #                          "environment, please use ({},) or ".format(len(observation_space.nvec)) +
-        #                          "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
-        # elif isinstance(observation_space, gym.spaces.MultiBinary):
-        #     if observation.shape == (observation_space.n,):
-        #         return False
-        #     elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
-        #         return True
-        #     else:
-        #         raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
-        #                          "environment, please use ({},) or ".format(observation_space.n) +
-        #                          "(n_env, {}) for the observation shape.".format(observation_space.n))
-        else:
-            raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
-                             .format(observation_space))
-
     def predict(self, observation: np.ndarray,
                 state: Optional[np.ndarray] = None,
                 mask: Optional[np.ndarray] = None,
@@ -383,36 +311,7 @@ def predict(self, observation: np.ndarray,
         :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
             (used in recurrent policies)
         """
-        # TODO: move this block to BasePolicy
-        # if state is None:
-        #     state = self.initial_state
-        # if mask is None:
-        #     mask = [False for _ in range(self.n_envs)]
-        observation = np.array(observation)
-        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
-
-        observation = observation.reshape((-1,) + self.observation_space.shape)
-        observation = th.as_tensor(observation).to(self.device)
-        with th.no_grad():
-            actions = self.policy.predict(observation, deterministic=deterministic)
-        # Convert to numpy
-        actions = actions.cpu().numpy()
-
-        # Rescale to proper domain when using squashing
-        if isinstance(self.action_space, gym.spaces.Box) and self.policy.squash_output:
-            actions = self.unscale_action(actions)
-
-        clipped_actions = actions
-        # Clip the actions to avoid out of bound error when using gaussian distribution
-        if isinstance(self.action_space, gym.spaces.Box) and not self.policy.squash_output:
-            clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
-
-        if not vectorized_env:
-            if state is not None:
-                raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
-            clipped_actions = clipped_actions[0]
-
-        return clipped_actions, state
+        return self.policy.predict(observation, state, mask, deterministic)

     @classmethod
     def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
@@ -484,10 +383,7 @@ def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[D
            raise ValueError(f"Error: the file {load_path} could not be found")

        # set device to cpu if cuda is not available
-       if th.cuda.is_available():
-           device = th.device('cuda')
-       else:
-           device = th.device('cpu')
+       device = th.device('cuda') if th.cuda.is_available() else th.device('cpu')

        # Open the zip archive and load data
        try:
@@ -534,20 +430,6 @@ def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[D
                    # load the parameters with the right `map_location`
                    params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)

-           # for backward compatibility
-           if params.get('params') is not None:
-               params_copy = {}
-               for name in params:
-                   if name == 'params':
-                       params_copy['policy'] = params[name]
-                   elif name == 'opt':
-                       params_copy['policy.optimizer'] = params[name]
-                   # Special case for SAC
-                   elif name == 'ent_coef_optimizer':
-                       params_copy[name] = params[name]
-                   else:
-                       params_copy[name + '.optimizer'] = params[name]
-               params = params_copy
        except zipfile.BadZipFile:
            # load_path wasn't a zip file
            raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
@@ -925,7 +807,7 @@ def collect_rollouts(self,
                unscaled_action, _ = self.predict(obs, deterministic=False)

                # Rescale the action from [low, high] to [-1, 1]
-               scaled_action = self.scale_action(unscaled_action)
+               scaled_action = self.policy.scale_action(unscaled_action)

                if self.use_sde:
                    # When using SDE, the action can be out of bounds
@@ -941,7 +823,7 @@
                    clipped_action = np.clip(clipped_action + action_noise(), -1, 1)

                # Rescale and perform action
-               new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
+               new_obs, reward, done, infos = env.step(self.policy.unscale_action(clipped_action))

                # Only stop training if return value is False, not when it is None.
                if callback.on_step() is False:
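
The relocated scale_action/unscale_action helpers are a plain linear map between the environment's [low, high] box and the squashed [-1, 1] range the network works in; moving them onto the policy, which already owns squash_output, keeps the whole predict path in one object. A standalone round-trip check of the two formulas from the removed methods (bounds chosen arbitrarily):

import numpy as np

low, high = np.array([-2.0]), np.array([2.0])  # arbitrary Box bounds

def scale_action(action):
    # [low, high] -> [-1, 1]
    return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(scaled_action):
    # [-1, 1] -> [low, high]
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

action = np.array([1.0])
assert np.allclose(unscale_action(scale_action(action)), action)  # exact round trip
print(scale_action(action))  # [0.5]: 1.0 sits three quarters of the way from -2 to 2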
