Commit 7cb5fce

WIP.
ViktorM committed Aug 18, 2023
1 parent 990b478 commit 7cb5fce
Showing 7 changed files with 18 additions and 8 deletions.
rl_games/algos_torch/sac_agent.py: 10 changes (8 additions, 2 deletions)
@@ -38,6 +38,8 @@ def __init__(self, base_name, params):
 self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
 self.normalize_input = config.get("normalize_input", False)
 
+self.save_freq = config.get('save_frequency', 0)
+
 self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach
 
 print(self.batch_size, self.num_actors, self.num_agents)
@@ -236,7 +238,11 @@ def set_weights(self, weights):
 def set_full_state_weights(self, weights):
     self.set_weights(weights)
 
-    self.step = weights['step']
+    for key in weights:
+        print("Set full state weights keys:")
+        print(key)
+
+    self.step = weights['steps']
     self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
     self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
     self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])
@@ -560,7 +566,7 @@ def train(self):
 should_exit = False
 
 if self.save_freq > 0:
-    if (self.epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+    if (self.epoch_num % self.save_freq) == 0:
         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
 if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
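Net effect of the sac_agent.py changes: the agent now reads save_frequency from its config (defaulting to 0, which disables periodic saving), the 'last_' checkpoint is written on a pure epoch cadence instead of only when the mean reward has failed to improve, and set_full_state_weights restores the step counter from weights['steps'] (previously 'step') while printing every key it finds in the checkpoint. A minimal sketch of the resulting checkpoint policy; the standalone function is hypothetical, but the attribute and config names (save_freq, epoch_num, nn_dir, save_best_after, last_mean_rewards, the 'last_' prefix) are the ones used in the diff:

import os

# Sketch only, not a verbatim excerpt from the repository.
def maybe_save_checkpoints(agent, epoch_num, mean_rewards, checkpoint_name):
    # Periodic 'last_' checkpoint: fires every save_freq epochs, no longer gated on reward.
    if agent.save_freq > 0 and (epoch_num % agent.save_freq) == 0:
        agent.save(os.path.join(agent.nn_dir, 'last_' + checkpoint_name))
    # Best-reward checkpoint: still written only when the mean reward improves and at
    # least save_best_after epochs have elapsed (the actual save call is elided here).
    if mean_rewards > agent.last_mean_rewards and epoch_num >= agent.save_best_after:
        agent.last_mean_rewards = mean_rewards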
rl_games/common/a2c_common.py: 4 changes (2 additions, 2 deletions)
@@ -1024,7 +1024,7 @@ def train(self):
 checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])
 
 if self.save_freq > 0:
-    if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards):
+    if (epoch_num % self.save_freq) == 0:
         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
 if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
@@ -1301,7 +1301,7 @@ def train(self):
 checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])
 
 if self.save_freq > 0:
-    if (epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+    if (epoch_num % self.save_freq) == 0:
         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
 if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
rl_games/common/player.py: 4 changes (2 additions, 2 deletions)
@@ -364,9 +364,9 @@ def run(self):
 cur_rewards_done = cur_rewards/done_count
 cur_steps_done = cur_steps/done_count
 if print_game_res:
-    print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4} w: {game_res}')
+    print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f} w: {game_res}')
 else:
-    print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4f}')
+    print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f}')
 
 sum_game_res += game_res
 if batch_size//self.num_agents == 1 or games_played >= n_games:
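For context on the player.py format-spec change: `:.4` is Python's general float format with four significant digits, so large rewards come out in scientific notation and trailing digits are dropped, whereas `:.2f` and `:.1f` always print a fixed number of decimal places. A quick illustration with made-up values:

reward, steps = 12345.678, 7.0
print(f'reward: {reward:.4} steps: {steps:.4}')    # reward: 1.235e+04 steps: 7.0
print(f'reward: {reward:.2f} steps: {steps:.1f}')  # reward: 12345.68 steps: 7.0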
rl_games/configs/mujoco/sac_ant_envpool.yaml: 2 changes (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ params:
 max_epochs: 10000
 num_steps_per_episode: 8
 save_best_after: 500
-save_frequency: 10000
+save_frequency: 1000
 gamma: 0.99
 init_alpha: 1
 alpha_lr: 5e-3
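With max_epochs: 10000 in this config, lowering save_frequency from 10000 to 1000 means the periodic 'last_' checkpoint is written roughly ten times per run rather than at most once at the very end; the best-reward checkpoint still only starts after save_best_after: 500. A throwaway check of that cadence, assuming the usual rl_games layout where these keys sit under params.config in the file from this diff:

import yaml

with open('rl_games/configs/mujoco/sac_ant_envpool.yaml') as f:
    cfg = yaml.safe_load(f)['params']['config']

# 10000 // 1000 == 10 periodic 'last_' checkpoints over a full run.
print(cfg['max_epochs'] // cfg['save_frequency'])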
rl_games/envs/envpool.py: 1 change (1 addition, 0 deletions)
@@ -11,6 +11,7 @@ def flatten_dict(obs):
     res = np.column_stack(res)
     return res
 
+
 class Envpool(IVecEnv):
     def __init__(self, config_name, num_actors, **kwargs):
         import envpool
rl_games/torch_runner.py: 2 changes (2 additions, 0 deletions)
@@ -49,6 +49,7 @@ def __init__(self, algo_observer=None):
 
 self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver()
 torch.backends.cudnn.benchmark = True
+
 ### it didnot help for lots for openai gym envs anyway :(
 #torch.backends.cudnn.deterministic = True
 #torch.use_deterministic_algorithms(True)
@@ -57,6 +58,7 @@ def reset(self):
     pass
 
 def load_config(self, params):
+    print("Loading config")
     self.seed = params.get('seed', None)
     if self.seed is None:
         self.seed = int(time.time())
runner.py: 3 changes (2 additions, 1 deletion)
@@ -26,6 +26,8 @@
 os.makedirs("nn", exist_ok=True)
 os.makedirs("runs", exist_ok=True)
 
+#breakpoint()
+
 args = vars(ap.parse_args())
 config_name = args['file']
 
@@ -50,7 +52,6 @@
 except yaml.YAMLError as exc:
     print(exc)
 
-#rank = int(os.getenv("LOCAL_RANK", "0"))
 global_rank = int(os.getenv("RANK", "0"))
 if args["track"] and global_rank == 0:
     import wandb