Added get/set_param functions to RL algos to support fast pbt #251

Merged
merged 5 commits on Sep 22, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -303,6 +303,8 @@ Additional environment supported properties and functions
* Fixed bug with SAC not saving weights with save_frequency.
* Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
* Added evaluation feature for inferencing during training. Checkpoints from the training process can be automatically picked up and updated in the inferencing process when enabled.
* Added get/set API for runtime update of RL training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
* Fixed SAC not loading weights properly.

1.6.0

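The README entry above, together with the new restore(fn, set_epoch=...) flag and the get_param/set_param API added in this PR, is what a population-based training (PBT) driver needs to mutate a worker in place. The snippet below is a minimal illustration only, not code from this PR; the driver function and the parameter name 'learning_rate' are assumptions, and the names actually supported depend on the algorithm.

def pbt_exploit_and_explore(agent, better_checkpoint, perturb=1.2):
    # Hypothetical PBT exploit/explore step; `agent` is an rl_games agent
    # exposing the restore/get_param/set_param API introduced in this PR.
    # Copy weights and optimizer state from a stronger population member,
    # but keep this worker's own epoch/frame counters (set_epoch=False).
    agent.restore(better_checkpoint, set_epoch=False)

    # Perturb a training hyperparameter at runtime; 'learning_rate' is an
    # assumed name used only for illustration.
    lr = agent.get_param('learning_rate')
    agent.set_param('learning_rate', lr * perturb)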
9 changes: 4 additions & 5 deletions rl_games/algos_torch/a2c_continuous.py
@@ -7,11 +7,10 @@

from torch import optim
import torch
from torch import nn
import numpy as np
import gym


class A2CAgent(a2c_common.ContinuousA2CBase):

def __init__(self, base_name, params):
a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape
@@ -68,9 +67,9 @@ def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_masked_action_values(self, obs, action_masks):
assert False
4 changes: 2 additions & 2 deletions rl_games/algos_torch/a2c_discrete.py
@@ -70,9 +70,9 @@ def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_masked_action_values(self, obs, action_masks):
processed_obs = self._preproc_obs(obs['obs'])
4 changes: 4 additions & 0 deletions rl_games/algos_torch/players.py
@@ -16,6 +16,7 @@ def rescale_actions(low, high, action):


class PpoPlayerContinuous(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)
self.network = self.config['network']
@@ -81,7 +82,9 @@ def restore(self, fn):
def reset(self):
self.init_rnn()


class PpoPlayerDiscrete(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)

@@ -185,6 +188,7 @@ def reset(self):


class SACPlayer(BasePlayer):

def __init__(self, params):
BasePlayer.__init__(self, params)
self.network = self.config['network']
89 changes: 57 additions & 32 deletions rl_games/algos_torch/sac_agent.py
@@ -38,6 +38,7 @@ def __init__(self, base_name, params):
self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
self.normalize_input = config.get("normalize_input", False)

# TODO: double-check! To use bootstrap instead?
self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach

print(self.batch_size, self.num_actors, self.num_agents)
@@ -60,7 +61,6 @@ def __init__(self, base_name, params):
'action_dim': self.env_info["action_space"].shape[0],
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'normalize_input' : self.normalize_input,
'normalize_input': self.normalize_input,
}
self.model = self.network.build(net_config)
@@ -88,12 +88,8 @@ def __init__(self, base_name, params):
self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
print("Target entropy", self.target_entropy)

self.step = 0
self.algo_observer = config['features']['observer']

# TODO: Is there a better way to get the maximum number of episodes?
self.max_episodes = torch.ones(self.num_actors, device=self._device)*self.num_steps_per_episode

def load_networks(self, params):
builder = model_builder.ModelBuilder()
self.config['network'] = builder.load(params)
@@ -133,6 +129,8 @@ def base_init(self, base_name, config):
self.max_epochs = self.config.get('max_epochs', -1)
self.max_frames = self.config.get('max_frames', -1)

self.save_freq = config.get('save_frequency', 0)

self.network = config['network']
self.rewards_shaper = config['reward_shaper']
self.num_agents = self.env_info.get('agents', 1)
@@ -146,10 +144,10 @@ def base_init(self, base_name, config):
self.min_alpha = torch.tensor(np.log(1)).float().to(self._device)

self.frame = 0
self.epoch_num = 0
self.update_time = 0
self.last_mean_rewards = -100500
self.last_mean_rewards = -1000000000
self.play_time = 0
self.epoch_num = 0

# TODO: put it into the separate class
pbt_str = ''
@@ -205,17 +203,8 @@ def alpha(self):
def device(self):
return self._device

def get_full_state_weights(self):
state = self.get_weights()

state['steps'] = self.step
state['actor_optimizer'] = self.actor_optimizer.state_dict()
state['critic_optimizer'] = self.critic_optimizer.state_dict()
state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()

return state

def get_weights(self):
print("Loading weights")
state = {'actor': self.model.sac_network.actor.state_dict(),
'critic': self.model.sac_network.critic.state_dict(),
'critic_target': self.model.sac_network.critic_target.state_dict()}
@@ -233,17 +222,45 @@ def set_weights(self, weights):
if self.normalize_input and 'running_mean_std' in weights:
self.model.running_mean_std.load_state_dict(weights['running_mean_std'])

def set_full_state_weights(self, weights):
def get_full_state_weights(self):
print("Loading full weights")
state = self.get_weights()

state['epoch'] = self.epoch_num
state['frame'] = self.frame
state['actor_optimizer'] = self.actor_optimizer.state_dict()
state['critic_optimizer'] = self.critic_optimizer.state_dict()
state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()

return state

def set_full_state_weights(self, weights, set_epoch=True):
self.set_weights(weights)

self.step = weights['step']
if set_epoch:
self.epoch_num = weights['epoch']
self.frame = weights['frame']

self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])

def restore(self, fn):
self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)

if self.vec_env is not None:
env_state = weights.get('env_state', None)
self.vec_env.set_env_state(env_state)

def restore(self, fn, set_epoch=True):
print("SAC restore")
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_param(self, param_name):
pass

def set_param(self, param_name, param_value):
pass

def get_masked_action_values(self, obs, action_masks):
assert False
@@ -334,6 +351,7 @@ def preproc_obs(self, obs):
if isinstance(obs, dict):
obs = obs['obs']
obs = self.model.norm_obs(obs)

return obs

def cast_obs(self, obs):
@@ -348,7 +366,7 @@ def cast_obs(self, obs):

return obs

# todo: move to common utils
# TODO: move to common utils
def obs_to_tensors(self, obs):
obs_is_dict = isinstance(obs, dict)
if obs_is_dict:
@@ -359,6 +377,7 @@ def obs_to_tensors(self, obs):
upd_obs = self.cast_obs(obs)
if not obs_is_dict or 'obs' not in obs:
upd_obs = {'obs' : upd_obs}

return upd_obs

def _obs_to_tensors_internal(self, obs):
@@ -368,18 +387,19 @@ def _obs_to_tensors_internal(self, obs):
upd_obs[key] = self._obs_to_tensors_internal(value)
else:
upd_obs = self.cast_obs(obs)

return upd_obs

def preprocess_actions(self, actions):
if not self.is_tensor_obses:
actions = actions.cpu().numpy()

return actions

def env_step(self, actions):
actions = self.preprocess_actions(actions)
obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space)

self.step += self.num_actors
if self.is_tensor_obses:
return self.obs_to_tensors(obs), rewards.to(self._device), dones.to(self._device), infos
else:
Expand Down Expand Up @@ -415,7 +435,7 @@ def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, act
def clear_stats(self):
self.game_rewards.clear()
self.game_lengths.clear()
self.mean_rewards = self.last_mean_rewards = -100500
self.mean_rewards = self.last_mean_rewards = -1000000000
self.algo_observer.after_clear_stats()

def play_steps(self, random_exploration = False):
@@ -431,6 +451,11 @@ def play_steps(self, random_exploration = False):
critic2_losses = []

obs = self.obs
if isinstance(obs, dict):
obs = self.obs['obs']

next_obs_processed = obs.clone()

for s in range(self.num_steps_per_episode):
self.set_eval()
if random_exploration:
@@ -466,16 +491,17 @@
self.current_rewards = self.current_rewards * not_dones
self.current_lengths = self.current_lengths * not_dones

if isinstance(obs, dict):
obs = obs['obs']
if isinstance(next_obs, dict):
next_obs = next_obs['obs']
next_obs_processed = next_obs['obs']

self.obs = next_obs.clone()

rewards = self.rewards_shaper(rewards)

self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1))
self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))

self.obs = obs = next_obs.clone()
if isinstance(obs, dict):
obs = self.obs['obs']

if not random_exploration:
self.set_train()
Expand Down Expand Up @@ -505,10 +531,9 @@ def train_epoch(self):
def train(self):
self.init_tensors()
self.algo_observer.after_init(self)
self.last_mean_rewards = -100500
total_time = 0
# rep_count = 0
self.frame = 0

self.obs = self.env_reset()

while True:
@@ -560,7 +585,7 @@ def train(self):
should_exit = False

if self.save_freq > 0:
if (self.epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
if self.epoch_num % self.save_freq == 0:
self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
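Note that in this diff the SAC agent adds get_param/set_param only as empty stubs. Purely as a sketch of how they might eventually be filled in (not code from this PR; the class name SACAgent, the parameter names, and the attribute mapping are assumptions), a simple name-to-attribute dispatch is enough for PBT-style mutation:

from rl_games.algos_torch.sac_agent import SACAgent  # assumed class/module name


class PBTSACAgent(SACAgent):
    """Hypothetical subclass sketch filling in the get/set_param stubs."""

    def get_param(self, param_name):
        # Read a training parameter by name; the names are illustrative only.
        if param_name == 'gamma':
            return self.gamma
        if param_name == 'critic_tau':
            return self.critic_tau
        if param_name == 'actor_lr':
            return self.actor_optimizer.param_groups[0]['lr']
        raise NotImplementedError(f'Unknown parameter: {param_name}')

    def set_param(self, param_name, param_value):
        # Update a training parameter in place, mirroring get_param above.
        if param_name == 'gamma':
            self.gamma = param_value
        elif param_name == 'critic_tau':
            self.critic_tau = param_value
        elif param_name == 'actor_lr':
            for group in self.actor_optimizer.param_groups:
                group['lr'] = param_value
        else:
            raise NotImplementedError(f'Unknown parameter: {param_name}')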