Skip to content

Commit

Permalink
Improved naming for fast PBT (population-based training).
Browse files — browse the repository at this point in the history
  • Loading branch information
ViktorM committed Sep 19, 2023
1 parent efb2035 commit 2514904
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
13 changes: 10 additions & 3 deletions rl_games/algos_torch/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(self, base_name, params):
'action_dim': self.env_info["action_space"].shape[0],
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'normalize_input' : self.normalize_input,
'normalize_input': self.normalize_input,
}
self.model = self.network.build(net_config)
Expand Down Expand Up @@ -224,6 +223,7 @@ def set_weights(self, weights):
self.model.running_mean_std.load_state_dict(weights['running_mean_std'])

def get_full_state_weights(self):
breakpoint()
print("Loading full weights")
state = self.get_weights()

Expand All @@ -246,14 +246,21 @@ def set_full_state_weights(self, weights, set_epoch=True):
self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])

self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)

if self.vec_env is not None:
env_state = weights.get('env_state', None)
self.vec_env.set_env_state(env_state)

def restore(self, fn, set_epoch=True):
print("SAC restore")
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

def get_params(self, param_name):
def get_param(self, param_name):
pass

def set_params(self, param_name, param_value):
def set_param(self, param_name, param_value):
pass

def get_masked_action_values(self, obs, action_masks):
Expand Down
4 changes: 2 additions & 2 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def set_weights(self, weights):
self.model.load_state_dict(weights['model'])
self.set_stats_weights(weights)

def get_params(self, param_name):
def get_param(self, param_name):
if param_name in [
"grad_norm",
"critic_coef",
Expand All @@ -678,7 +678,7 @@ def get_params(self, param_name):
else:
raise NotImplementedError(f"Can't get param {param_name}")

def set_params(self, param_name, param_value):
def set_param(self, param_name, param_value):
if param_name in [
"grad_norm",
"critic_coef",
Expand Down
4 changes: 2 additions & 2 deletions rl_games/common/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,9 @@ def run(self):
cur_rewards_done = cur_rewards/done_count
cur_steps_done = cur_steps/done_count
if print_game_res:
print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4} w: {game_res}')
print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f} w: {game_res}')
else:
print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4f}')
print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f}')

sum_game_res += game_res
if batch_size//self.num_agents == 1 or games_played >= n_games:
Expand Down
4 changes: 2 additions & 2 deletions rl_games/interfaces/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def set_weights(self, weights):

# Get algo training parameters
@abstractmethod
def get_params(self, param_name):
def get_param(self, param_name):
pass

# Set algo training parameters
@abstractmethod
def set_params(self, param_name, param_value):
def set_param(self, param_name, param_value):
pass


5 changes: 1 addition & 4 deletions rl_games/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def __init__(self, algo_observer=None):
#torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)

#breakpoint()

def reset(self):
pass

Expand Down Expand Up @@ -120,7 +118,7 @@ def run_train(self, args):
agent.train()

def run_play(self, args):
print('Started to play')
breakpoint()
player = self.create_player()
_restore(player, args)
_override_sigma(player, args)
Expand All @@ -137,7 +135,6 @@ def run(self, args):
print('Started to train')
self.run_train(args)
elif args['play']:
print('Started to play')
self.run_play(args)
else:
print('Started to train2')
Expand Down

0 comments on commit 2514904

Please sign in to comment.