Added get/set_param functions to support fast pbt, without restarting training from scratch.
ViktorM committed Aug 22, 2023
1 parent 990b478 commit 289e32f
Showing 5 changed files with 117 additions and 42 deletions.
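
In effect, the commit lets a PBT controller reuse a live training run: a worker can pull in a better worker's checkpoint without resetting its epoch/frame counters, and then have individual hyperparameters mutated through the new get_params/set_params accessors. A minimal sketch of such an exploit-and-explore step is shown below; the agent handle, checkpoint path, parameter names and perturbation range are illustrative assumptions rather than code from this commit, and direct learning-rate mutation is only accepted when the adaptive-KL scheduler is not in charge of the LR.

```python
import random

def pbt_exploit_and_explore(agent, better_checkpoint, perturb=(0.8, 1.2)):
    # Exploit: copy weights/optimizer state from the better worker, but keep
    # this worker's own epoch/frame counters (that is what set_epoch=False does).
    agent.restore(better_checkpoint, set_epoch=False)

    # Explore: perturb a couple of the whitelisted hyperparameters in place.
    for name in ("learning_rate", "entropy_coef"):
        value = agent.get_params(name)
        agent.set_params(name, value * random.uniform(*perturb))
```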
9 changes: 4 additions & 5 deletions rl_games/algos_torch/a2c_continuous.py
@@ -7,11 +7,10 @@

from torch import optim
import torch
from torch import nn
import numpy as np
import gym


class A2CAgent(a2c_common.ContinuousA2CBase):

def __init__(self, base_name, params):
a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape
@@ -68,9 +67,9 @@ def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_masked_action_values(self, obs, action_masks):
assert False
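
The restore signature above defaults to set_epoch=True, so existing callers keep the old full-resume behaviour; only a PBT-style caller passes set_epoch=False. Roughly (the checkpoint path is a placeholder):

```python
agent.restore("nn/last_checkpoint.pth")                   # full resume: counters come from the checkpoint
agent.restore("nn/last_checkpoint.pth", set_epoch=False)  # keep current epoch/frame; load weights/optimizer only
```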
4 changes: 2 additions & 2 deletions rl_games/algos_torch/a2c_discrete.py
@@ -70,9 +70,9 @@ def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_masked_action_values(self, obs, action_masks):
processed_obs = self._preproc_obs(obs['obs'])
59 changes: 34 additions & 25 deletions rl_games/algos_torch/sac_agent.py
@@ -38,6 +38,7 @@ def __init__(self, base_name, params):
self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
self.normalize_input = config.get("normalize_input", False)

# TODO: double-check! To use bootstrap instead?
self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach

print(self.batch_size, self.num_actors, self.num_agents)
@@ -88,12 +89,8 @@ def __init__(self, base_name, params):
self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
print("Target entropy", self.target_entropy)

self.step = 0
self.algo_observer = config['features']['observer']

# TODO: Is there a better way to get the maximum number of episodes?
self.max_episodes = torch.ones(self.num_actors, device=self._device)*self.num_steps_per_episode

def load_networks(self, params):
builder = model_builder.ModelBuilder()
self.config['network'] = builder.load(params)
@@ -146,10 +143,10 @@ def base_init(self, base_name, config):
self.min_alpha = torch.tensor(np.log(1)).float().to(self._device)

self.frame = 0
self.epoch_num = 0
self.update_time = 0
self.last_mean_rewards = -100500
self.last_mean_rewards = -1000000000
self.play_time = 0
self.epoch_num = 0

# TODO: put it into the separate class
pbt_str = ''
@@ -205,16 +202,6 @@ def alpha(self):
def device(self):
return self._device

def get_full_state_weights(self):
state = self.get_weights()

state['steps'] = self.step
state['actor_optimizer'] = self.actor_optimizer.state_dict()
state['critic_optimizer'] = self.critic_optimizer.state_dict()
state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()

return state

def get_weights(self):
state = {'actor': self.model.sac_network.actor.state_dict(),
'critic': self.model.sac_network.critic.state_dict(),
@@ -233,17 +220,37 @@ def set_weights(self, weights):
if self.normalize_input and 'running_mean_std' in weights:
self.model.running_mean_std.load_state_dict(weights['running_mean_std'])

def set_full_state_weights(self, weights):
def get_full_state_weights(self):
state = self.get_weights()

state['epoch'] = self.epoch_num
state['frame'] = self.frame
state['actor_optimizer'] = self.actor_optimizer.state_dict()
state['critic_optimizer'] = self.critic_optimizer.state_dict()
state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()

return state

def set_full_state_weights(self, weights, set_epoch=True):
self.set_weights(weights)

self.step = weights['step']
if set_epoch:
self.epoch_num = weights['epoch']
self.frame = weights['frame']

self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])

def restore(self, fn):
def restore(self, fn, set_epoch=True):
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_params(self, param_name):
pass

def set_params(self, param_name, param_value):
pass

def get_masked_action_values(self, obs, action_masks):
assert False
@@ -334,6 +341,7 @@ def preproc_obs(self, obs):
if isinstance(obs, dict):
obs = obs['obs']
obs = self.model.norm_obs(obs)

return obs

def cast_obs(self, obs):
@@ -348,7 +356,7 @@ def cast_obs(self, obs):

return obs

# todo: move to common utils
# TODO: move to common utils
def obs_to_tensors(self, obs):
obs_is_dict = isinstance(obs, dict)
if obs_is_dict:
@@ -359,6 +367,7 @@ def obs_to_tensors(self, obs):
upd_obs = self.cast_obs(obs)
if not obs_is_dict or 'obs' not in obs:
upd_obs = {'obs' : upd_obs}

return upd_obs

def _obs_to_tensors_internal(self, obs):
@@ -368,18 +377,19 @@ def _obs_to_tensors_internal(self, obs):
upd_obs[key] = self._obs_to_tensors_internal(value)
else:
upd_obs = self.cast_obs(obs)

return upd_obs

def preprocess_actions(self, actions):
if not self.is_tensor_obses:
actions = actions.cpu().numpy()

return actions

def env_step(self, actions):
actions = self.preprocess_actions(actions)
obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space)

self.step += self.num_actors
if self.is_tensor_obses:
return self.obs_to_tensors(obs), rewards.to(self._device), dones.to(self._device), infos
else:
Expand Down Expand Up @@ -415,7 +425,7 @@ def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, act
def clear_stats(self):
self.game_rewards.clear()
self.game_lengths.clear()
self.mean_rewards = self.last_mean_rewards = -100500
self.mean_rewards = self.last_mean_rewards = -1000000000
self.algo_observer.after_clear_stats()

def play_steps(self, random_exploration = False):
@@ -505,10 +515,9 @@ def train_epoch(self):
def train(self):
self.init_tensors()
self.algo_observer.after_init(self)
self.last_mean_rewards = -100500
total_time = 0
# rep_count = 0
self.frame = 0

self.obs = self.env_reset()

while True:
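
Note that sac_agent.py above only stubs out get_params/set_params for now (both bodies are pass). If SAC needed to join the same fast-PBT flow, one possible implementation could mirror the A2C pattern by going through the optimizers the agent already creates. The sketch below is an illustration only, not part of this commit, and it assumes the actor and critic share one configured learning rate.

```python
# Hypothetical SACAgent methods, not in this commit.
def get_params(self, param_name):
    if param_name == "learning_rate":
        return self.actor_optimizer.param_groups[0]["lr"]
    raise NotImplementedError(f"Can't get param {param_name}")

def set_params(self, param_name, param_value):
    if param_name == "learning_rate":
        # Apply the new rate to both the actor and critic optimizers.
        for opt in (self.actor_optimizer, self.critic_optimizer):
            for group in opt.param_groups:
                group["lr"] = param_value
    else:
        raise NotImplementedError(f"Can't set param {param_name}")
```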
75 changes: 66 additions & 9 deletions rl_games/common/a2c_common.py
@@ -161,6 +161,7 @@ def __init__(self, base_name, params):
self.rnn_states = None
self.name = base_name

# TODO: do we still need it?
self.ppo = config.get('ppo', True)
self.max_epochs = self.config.get('max_epochs', -1)
self.max_frames = self.config.get('max_frames', -1)
@@ -229,22 +230,26 @@ def __init__(self, base_name, params):
self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
self.obs = None
self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation

self.batch_size = self.horizon_length * self.num_actors * self.num_agents
self.batch_size_envs = self.horizon_length * self.num_actors

assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
self.mini_epochs_num = self.config['mini_epochs']

self.num_minibatches = self.batch_size // self.minibatch_size
assert(self.batch_size % self.minibatch_size == 0)

self.mini_epochs_num = self.config['mini_epochs']

self.mixed_precision = self.config.get('mixed_precision', False)
self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

self.last_lr = self.config['learning_rate']
self.frame = 0
self.update_time = 0
self.mean_rewards = self.last_mean_rewards = -100500
self.mean_rewards = self.last_mean_rewards = -1000000000
self.play_time = 0
self.epoch_num = 0
self.curr_frames = 0
@@ -588,10 +593,11 @@ def train_central_value(self):
def get_full_state_weights(self):
state = self.get_weights()
state['epoch'] = self.epoch_num
state['frame'] = self.frame
state['optimizer'] = self.optimizer.state_dict()

if self.has_central_value:
state['assymetric_vf_nets'] = self.central_value_net.state_dict()
state['frame'] = self.frame

# This is actually the best reward ever achieved. last_mean_rewards is perhaps not the best variable name
# We save it to the checkpoint to prevent overriding the "best ever" checkpoint upon experiment restart
@@ -603,19 +609,22 @@ def get_full_state_weights(self):

return state

def set_full_state_weights(self, weights):
def set_full_state_weights(self, weights, set_epoch=True):

self.set_weights(weights)
self.epoch_num = weights['epoch'] # frames as well?
if set_epoch:
self.epoch_num = weights['epoch']
self.frame = weights['frame']

if self.has_central_value:
self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])

self.optimizer.load_state_dict(weights['optimizer'])
self.frame = weights.get('frame', 0)
self.last_mean_rewards = weights.get('last_mean_rewards', -100500)

env_state = weights.get('env_state', None)
self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)

if self.vec_env is not None:
env_state = weights.get('env_state', None)
self.vec_env.set_env_state(env_state)

def get_weights(self):
@@ -651,6 +660,55 @@ def set_weights(self, weights):
self.model.load_state_dict(weights['model'])
self.set_stats_weights(weights)

def get_params(self, param_name):
if param_name in [
"grad_norm",
"critic_coef",
"bounds_loss_coef",
"entropy_coef",
"kl_threshold",
"gamma",
"tau",
"mini_epochs_num",
"e_clip",
]:
return getattr(self, param_name)
elif param_name == "learning_rate":
return self.last_lr
else:
raise NotImplementedError(f"Can't get param {param_name}")

def set_params(self, param_name, param_value):
if param_name in [
"grad_norm",
"critic_coef",
"bounds_loss_coef",
"entropy_coef",
"gamma",
"tau",
"mini_epochs_num",
"e_clip",
]:
setattr(self, param_name, param_value)
elif param_name == "learning_rate":
if self.global_rank == 0:
if self.is_adaptive_lr:
raise NotImplementedError("Can't directly mutate LR on this schedule")
else:
self.learning_rate = param_value

for param_group in self.optimizer.param_groups:
param_group["lr"] = self.learning_rate
elif param_name == "kl_threshold":
if self.global_rank == 0:
if self.is_adaptive_lr:
self.kl_threshold = param_value
self.scheduler.kl_threshold = param_value
else:
raise NotImplementedError("Can't directly mutate kl threshold")
else:
raise NotImplementedError(f"No param found for {param_value}")

def _preproc_obs(self, obs_batch):
if type(obs_batch) is dict:
obs_batch = copy.copy(obs_batch)
@@ -912,7 +970,6 @@ def prepare_dataset(self, batch_dict):
values = self.value_mean_std(values)
returns = self.value_mean_std(returns)
self.value_mean_std.eval()


advantages = torch.sum(advantages, axis=1)

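
The new A2C get_params/set_params above form a string-keyed accessor over a fixed allow-list of training hyperparameters (grad_norm, critic_coef, bounds_loss_coef, entropy_coef, gamma, tau, mini_epochs_num, e_clip), with learning_rate and kl_threshold handled separately because they interact with the LR scheduler: direct LR mutation is rejected when the adaptive-KL schedule is active, while kl_threshold can only be mutated when it is. A rough usage sketch with placeholder values:

```python
# Read back the current settings (learning_rate returns agent.last_lr).
gamma = agent.get_params("gamma")
lr = agent.get_params("learning_rate")

# Mutate them in place; setting learning_rate also pushes the new value
# into the optimizer's param groups.
agent.set_params("gamma", min(0.999, gamma * 1.01))
agent.set_params("entropy_coef", 0.005)
agent.set_params("learning_rate", lr * 0.5)
```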
12 changes: 11 additions & 1 deletion rl_games/interfaces/base_algorithm.py
@@ -1,6 +1,7 @@
from abc import ABC
from abc import abstractmethod, abstractproperty


class BaseAlgorithm(ABC):
def __init__(self, base_name, config):
pass
@@ -26,7 +27,7 @@ def get_full_state_weights(self):
pass

@abstractmethod
def set_full_state_weights(self, weights):
def set_full_state_weights(self, weights, set_epoch):
pass

@abstractmethod
@@ -37,5 +38,14 @@ def get_weights(self):
def set_weights(self, weights):
pass

# Get algo training parameters
@abstractmethod
def get_params(self, param_name):
pass

# Set algo training parameters
@abstractmethod
def set_params(self, param_name, param_value):
pass
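
Since both methods are declared with @abstractmethod (and set_full_state_weights now takes set_epoch), any external algorithm that subclasses BaseAlgorithm has to provide them before it can be instantiated. A minimal compliant stub might look like the following, where MyAlgo is a hypothetical subclass:

```python
class MyAlgo(BaseAlgorithm):
    # ... the other abstract methods are implemented elsewhere ...

    def get_params(self, param_name):
        raise NotImplementedError(f"Can't get param {param_name}")

    def set_params(self, param_name, param_value):
        raise NotImplementedError(f"Can't set param {param_name}")
```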

